JIT: Eliminate SumToSize by using Optional Lists (#18697)

Summary:
This PR eliminates unneeded grad_sum_to_size calls and, in particular, speeds up the LSTM backward by allowing better fusion.

It consists of two parts:
- In AutoDiff, record broadcasting sizes only if the broadcast output size differs from the input size; otherwise record None.
- The specialization of Optional arguments (#18407) then allows us to eliminate `_grad_sum_to_size(t, None)` in the peephole optimization step.

Thus, in the LSTM case, no SumToSize ops remain in the crucial fusion group. The trick here is that we can specialize on the runtime information from the forward pass.
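
Concretely, the two ops involved have the following runtime semantics (a minimal Python sketch mirroring the `register_prim_ops.cpp` changes below; the function names and type hints are illustrative only):

```python
from typing import List, Optional

import torch


def size_if_not_equal(self_size: List[int], other_size: List[int]) -> Optional[List[int]]:
    # Record the original size only when broadcasting actually changed it;
    # otherwise record None so the gradient reduction below becomes a no-op.
    return None if self_size == other_size else self_size


def grad_sum_to_size(grad: torch.Tensor, size: Optional[List[int]]) -> torch.Tensor:
    # size=None means "no broadcasting happened": the op is an identity and
    # the peephole pass can remove it entirely, keeping the fusion group whole.
    if size is None:
        return grad
    return grad.sum_to_size(*size)  # sum the broadcasted dims back to the input shape
```

Because the forward records `None` whenever no broadcasting happened, the Optional specialization (#18407) lets the peephole pass drop the resulting `_grad_sum_to_size(t, None)` nodes from the backward graph.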

I'm testing that different broadcasting situations lead to different graphs.

I didn't move all symbolic_script uses of _grad_sum_to_size to the new logic, but it might be better to do this incrementally anyway.
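
For the formulas that were migrated, the new pattern is visible in the `symbolic_script.cpp` hunk below (`mul`, `div`, `max`); migrating one of the remaining formulas would follow the same shape. As a hedged illustration only (TorchScript in the symbolic_script style, not part of this PR), `min` could look like this, reusing the `AD_sizes_if_not_equal_multi` helper introduced here:

```python
def min(self, other):
    result = torch.min(self, other)
    # AD_sizes_if_not_equal_multi (added in this PR) returns None for an input
    # whose size already equals the result's size.
    self_size, other_size = AD_sizes_if_not_equal_multi(self, other, result)

    def backward(grad_output):
        # gradient flows to whichever input provided the minimum element
        grad_self = (grad_output * (self < other).type_as(grad_output))._grad_sum_to_size(self_size)
        grad_other = (grad_output * (other < self).type_as(grad_output))._grad_sum_to_size(other_size)
        return grad_self, grad_other

    return result, backward
```

As with `max`, the size recording happens in the forward, so the backward graph stays fusible and the `_grad_sum_to_size` calls vanish whenever no broadcasting occurred.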
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18697

Differential Revision: D15482076

Pulled By: wanchaol

fbshipit-source-id: 7f89367e35b8729910077c95c02bccefc8678afb
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 6bd3c4b..1f1d1f4 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -83,6 +83,7 @@
   _(prim, abs)                     \
   _(prim, rangelist)               \
   _(aten, _grad_sum_to_size)       \
+  _(aten, _size_if_not_equal)      \
   _(aten, _ncf_unsqueeze)          \
   _(aten, warn)                    \
   _(aten, floordiv)                \
diff --git a/test/cpp/jit/test_autodiff.h b/test/cpp/jit/test_autodiff.h
index 3da4047..d8d2ba4 100644
--- a/test/cpp/jit/test_autodiff.h
+++ b/test/cpp/jit/test_autodiff.h
@@ -182,7 +182,7 @@
 
   auto grad_spec = differentiate(graph);
   std::vector<size_t> expected_captured_inputs = {0, 1};
-  std::vector<size_t> expected_captured_outputs = {1, 2};
+  std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5, 6, 7};
   std::vector<size_t> expected_input_vjps = {0, 1};
   std::vector<size_t> expected_output_vjps = {0, 1};
   ASSERT_EQ(grad_spec.f_real_outputs, 1);
@@ -228,7 +228,9 @@
   std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
   ASSERT_EQ(grad_spec.f_real_outputs, 2);
   ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
-  ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3}));
+  ASSERT_EQ(
+      grad_spec.df_input_captured_outputs,
+      std::vector<size_t>({2, 3, 4, 5, 6}));
   ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
   ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
   testing::FileCheck()
diff --git a/test/test_jit.py b/test/test_jit.py
index bd4790a..a1099fe 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -200,6 +200,15 @@
     return grad_executors[diff_graph_idx or 0]
 
 
+def all_backward_graphs(script_module, diff_graph_idx=None):
+    # Note: for Python 2 the order seems to be unstable
+    ge_state = script_module.get_debug_state()
+    fwd_plan = get_execution_plan(ge_state)
+    grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
+    bwd_plans = list(grad_executor_state.execution_plans.values())
+    return [p.graph.copy() for p in bwd_plans]
+
+
 def backward_graph(script_module, diff_graph_idx=None):
     ge_state = script_module.get_debug_state()
     fwd_plan = get_execution_plan(ge_state)
diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py
index 1d36383..3ab2e07 100644
--- a/test/test_jit_fuser.py
+++ b/test/test_jit_fuser.py
@@ -15,7 +15,7 @@
 from itertools import product, permutations
 
 from test_jit import JitTestCase, enable_cpu_fuser, RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, \
-    backward_graph, get_lstm_inputs, get_milstm_inputs, LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell
+    backward_graph, all_backward_graphs, get_lstm_inputs, get_milstm_inputs, LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell
 
 
 class TestFuser(JitTestCase):
@@ -275,7 +275,7 @@
         for f, inputs in product(funcs, [[a, b], [a, nan]]):
             inp1, inp2 = inputs
             s = self.checkScript(f, (inp1, inp2))
-            self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size'})
+            self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
 
             c = s(inp1, inp2)
             c.sum().backward()
@@ -350,7 +350,8 @@
         self.assertAllFused(ge.graph_for(x, y))
         x.requires_grad_(True)
         y.requires_grad_(True)
-        self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes"))
+        self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
+                                                            "aten::_size_if_not_equal"))
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -522,7 +523,8 @@
         self.assertAllFused(scripted.graph_for(x, p))
         x.requires_grad_(True)
         out = scripted(x, p)
-        self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes"))
+        self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes",
+                                                                  "aten::_size_if_not_equal"))
 
     @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
     @enable_cpu_fuser
@@ -535,7 +537,7 @@
         b = torch.randn(5, 5, requires_grad=True)
         a = torch.randn(5, 5, requires_grad=True)
         s = self.checkScript(f, (a, b))
-        self.assertAllFused(s.graph_for(a, b), except_for={'aten::size'})
+        self.assertAllFused(s.graph_for(a, b), except_for={'aten::size', 'aten::_size_if_not_equal', 'prim::BroadcastSizes'})
 
         c = s(a, b)
         ga, gb = torch.autograd.grad(c.sum(), [a, b])
@@ -578,12 +580,12 @@
 
         s = self.checkScript(iou, (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
         self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2),
-                            except_for={'aten::size', 'prim::BroadcastSizes'})
+                            except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})
 
         c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
         torch.autograd.grad(c.sum(), [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2])
         graph = backward_graph(s)
-        self.assertAllFused(graph, except_for={'aten::size', 'prim::BroadcastSizes'})
+        self.assertAllFused(graph, except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -670,8 +672,8 @@
         hy, cy = module(*inputs)
         (hy + cy).sum().backward()
         backward = backward_graph(module)
-        FileCheck().check("FusionGroup_0").check_next("FusionGroup_1") \
-            .check_not("FusionGroup_2").run(str(backward))
+        self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
+                                                  "aten::_grad_sum_to_size"))
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -801,7 +803,8 @@
         ge = self.checkTrace(fn_test_erf, (x,))
         self.assertAllFused(ge.graph_for(x))
         x.requires_grad_(True)
-        self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes"))
+        self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes",
+                                                         "aten::_size_if_not_equal"))
 
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@@ -818,7 +821,8 @@
         self.assertAllFused(script_f.graph_for(x, y))
         x.requires_grad_(True)
         out = script_f(x, y)
-        self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes"))
+        self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes",
+                                                                  "aten::_size_if_not_equal"))
         # test that broadcasting random produces correct results
         x = torch.ones(4, 4, dtype=torch.float, device='cuda')
         y = torch.ones(4, dtype=torch.float, device='cuda')
@@ -894,6 +898,44 @@
         self.assertEqual(result2, expected2)
         self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
 
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @skipIfRocm
+    def test_grad_sum_to_size_elimination(self):
+
+        def my_broadcasted_cell(a, b, c):
+            return (a + b) + c
+
+        s1 = torch.randn(5, 1, requires_grad=True, device='cuda')
+        s2 = torch.randn(5, 5, requires_grad=True, device='cuda')
+
+        module = self.checkScript(my_broadcasted_cell, (s1, s1, s1))
+        forward_graph = module.graph_for(s1, s1, s1)
+        self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes",
+                                                       "aten::_size_if_not_equal"))
+
+        old_plans = set()
+        for i in range(3):
+            # if we have s2, then the s1 are _grad_sum_to_size'd
+            args = s2 if i < 1 else s1, s2 if i < 2 else s1, s2
+            args = [a.detach_().requires_grad_() for a in args]
+            res = module(s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
+            grads = torch.autograd.grad(res.sum(), args)
+            for inp, gr in zip(args, grads):
+                self.assertEqual(inp.shape, gr.shape)
+            backward = None
+            # this is a workaround for the backward graphs not being
+            # in order for Python 2
+            for g in all_backward_graphs(module):
+                if str(g) not in old_plans:
+                    assert backward is None
+                    backward = g
+                    old_plans.add(str(backward))
+            self.assertEqual(len([1 for o in backward.outputs() if o.node().kind() == "aten::_grad_sum_to_size"]), i)
+            self.assertEqual(len([1 for o in backward.outputs() if o.node().kind() == "prim::Param"]), 3 - i)
+
+
     @unittest.skipIf(not IS_WINDOWS, "Test that the fuser is disabled on Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     def test_windows_cuda(self):
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index 41482c1..dd77ee8 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -213,11 +213,17 @@
  private:
   Node* node;
 
-  SymbolicVariable gradSumToSizeOf(SymbolicVariable v, Symbol input_name) {
+  SymbolicVariable gradSumToSizeOf(
+      SymbolicVariable v,
+      Symbol input_name,
+      SymbolicVariable fw_output) {
     Value* size;
     {
-      WithInsertPoint insert_guard{node};
-      size = SymbolicVariable(node->namedInput(input_name)).size();
+      // We insert after the current node because we want to use
+      // its output.
+      WithInsertPoint insert_guard{node->next()};
+      size = SymbolicVariable(node->namedInput(input_name))
+                 .size_if_not_equal(fw_output);
     }
     return v.gradSumToSize(size);
   };
@@ -237,9 +243,11 @@
 
     if (node->matches(
             "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
-      return {gradSumToSizeOf(grads.at(0), attr::self),
+      return {gradSumToSizeOf(grads.at(0), attr::self, outputs.at(0)),
               gradSumToSizeOf(
-                  grads.at(0) * node->namedInput(attr::alpha), attr::other),
+                  grads.at(0) * node->namedInput(attr::alpha),
+                  attr::other,
+                  outputs.at(0)),
               nullptr};
 
     } else if (
@@ -254,9 +262,11 @@
     } else if (
         node->matches(
             "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
-      return {gradSumToSizeOf(grads.at(0), attr::self),
+      return {gradSumToSizeOf(grads.at(0), attr::self, outputs.at(0)),
               gradSumToSizeOf(
-                  -grads.at(0) * node->namedInput(attr::alpha), attr::other),
+                  -grads.at(0) * node->namedInput(attr::alpha),
+                  attr::other,
+                  outputs.at(0)),
               nullptr};
 
     } else if (
@@ -337,7 +347,9 @@
         node->matches(
             "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) {
       return {gradSumToSizeOf(
-                  grads.at(0) * node->namedInput(attr::beta), attr::self),
+                  grads.at(0) * node->namedInput(attr::beta),
+                  attr::self,
+                  outputs.at(0)),
               grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
               inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
               nullptr,
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index 2e397bc..79db4c0 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -813,6 +813,10 @@
     // The output of producer_for_chunk_node could have been used in some
     // aten::size operators, so we need to clean those up as well (we simply
     // broadcast all its tensor inputs).
+    // We need to insert these early in the graph, i.e. immediately after
+    // the producer_for_chunk_node, because the _size_if_not_equal nodes
+    // may appear before the bchunk.
+    WithInsertPoint guard2(producer_for_chunk_node);
     auto size_calc_uses = producer_for_chunk_node->output()->uses();
     if (!size_calc_uses.empty()) {
       auto tensor_inputs = filter(
diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp
index ff2bd6d..2e94991 100644
--- a/torch/csrc/jit/passes/peephole.cpp
+++ b/torch/csrc/jit/passes/peephole.cpp
@@ -157,12 +157,17 @@
       }
     } else if (
         node->matches(
-            "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)")) {
-      auto uses = node->output()->uses();
-      for (Use u : uses) {
-        if (u.user->matches(
-                "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)")) {
-          u.user->replaceInput(0, node->inputs().at(0));
+            "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)")) {
+      if (node->input(1)->mustBeNone()) {
+        node->output()->replaceAllUsesWith(node->input(0));
+      } else {
+        auto uses = node->output()->uses();
+        for (Use u : uses) {
+          if (u.user->matches(
+                  "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)") &&
+              u.user->input(1)->type()->isSubtypeOf(ListType::ofInts())) {
+            u.user->replaceInput(0, node->inputs().at(0));
+          }
         }
       }
     } else if (node->kind() == prim::If) {
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index 4a5aa4d..f189a39 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -655,12 +655,30 @@
            };
          }),
      Operator(
-         "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)",
+         "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)",
          [](Stack& stack) {
-           at::Tensor self;
-           Shared<IntList> desired_sizes;
-           pop(stack, self, desired_sizes);
-           push(stack, at::sum_to(std::move(self), desired_sizes->elements()));
+           IValue self, size;
+           pop(stack, self, size);
+           if (size.isNone()) {
+             push(stack, self);
+           } else {
+             push(
+                 stack,
+                 at::sum_to(self.toTensor(), size.toIntList()->elements()));
+           }
+           return 0;
+         }),
+     Operator(
+         "aten::_size_if_not_equal(int[] self_size, int[] other_size) -> int[]?",
+         [](Stack& stack) {
+           IValue self_size, other_size;
+           pop(stack, self_size, other_size);
+           const auto s = self_size.toIntList()->elements();
+           if (s == other_size.toIntList()->elements()) {
+             push(stack, IValue());
+           } else {
+             push(stack, s);
+           }
            return 0;
          }),
      Operator(
diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp
index 1a2b27c..8b6390f 100644
--- a/torch/csrc/jit/symbolic_script.cpp
+++ b/torch/csrc/jit/symbolic_script.cpp
@@ -466,16 +466,6 @@
                 return grad_self, grad_end, None
             return torch.lerp(self, end, weight), backward
 
-        def mul(self, other):
-            def backward(grad_output):
-                # self & other are used in backward. No need to pass in their size
-                # from forward pass
-                grad_self = (grad_output * other)._grad_sum_to_size(self.size())
-                grad_other = (grad_output * self)._grad_sum_to_size(other.size())
-                return grad_self, grad_other
-
-            return self * other, backward
-
         def reshape(self,
                     shape: List[int]):
             self_size = self.size()
@@ -698,21 +688,43 @@
 
     )",
     R"(
-        def div(self, other):
+        def AD_sizes_if_not_equal_multi(t1, t2, res):
+            return torch._size_if_not_equal(t1.size(), res.size()), torch._size_if_not_equal(t2.size(), res.size())
+
+        def mul(self, other):
+            result = self * other
+            self_size, other_size = AD_sizes_if_not_equal_multi(self, other, result)
+
             def backward(grad_output):
-                grad_self = (grad_output / other)._grad_sum_to_size(self.size())
-                grad_other = (-grad_output * self / (other * other))._grad_sum_to_size(other.size())
+                # self & other are used in backward. No need to pass in their size
+                # from forward pass
+                grad_self = (grad_output * other)._grad_sum_to_size(self_size)
+                grad_other = (grad_output * self)._grad_sum_to_size(other_size)
                 return grad_self, grad_other
 
-            return self / other, backward
+            return result, backward
+
+        def div(self, other):
+            result = self / other
+            self_size, other_size = AD_sizes_if_not_equal_multi(self, other, result)
+
+            def backward(grad_output):
+                grad_self = (grad_output / other)._grad_sum_to_size(self_size)
+                grad_other = (-grad_output * self / (other * other))._grad_sum_to_size(other_size)
+                return grad_self, grad_other
+
+            return result, backward
 
         def max(self, other):
+            result = torch.max(self, other)
+            self_size, other_size = AD_sizes_if_not_equal_multi(self, other, result)
+
             def backward(grad_output):
-                grad_self = (grad_output * (self > other).type_as(grad_output))._grad_sum_to_size(self.size())
-                grad_other = (grad_output * (other > self).type_as(grad_output))._grad_sum_to_size(other.size())
+                grad_self = (grad_output * (self > other).type_as(grad_output))._grad_sum_to_size(self_size)
+                grad_other = (grad_output * (other > self).type_as(grad_output))._grad_sum_to_size(other_size)
                 return grad_self, grad_other
 
-            return torch.max(self, other), backward
+            return result, backward
 
         def min(self, other):
             def backward(grad_output):
@@ -948,12 +960,21 @@
             return torch.trunc(self), backward
 
         def _grad_sum_to_size(self,
-                              size: List[int]):
-            self_size = self.size()
-            def backward(grad_output):
-                return grad_output.expand(self_size), None
+                              size: Optional[List[int]]):
+            if size is not None:
+                self_size = self.size()
+            else:
+                self_size = None
 
-            return torch._grad_sum_to_size(self, size), backward
+            result = torch._grad_sum_to_size(self, size)
+            def backward(grad_output):
+                if self_size is None:
+                    grad_input = grad_output
+                else:
+                    grad_input = grad_output.expand(self_size)
+                return grad_input, None
+
+            return result, backward
     )",
     R"(
         def AD_adaptive_avg_pool2d_backward(grad,
diff --git a/torch/csrc/jit/symbolic_variable.h b/torch/csrc/jit/symbolic_variable.h
index 8fb8feb..c1bf672 100644
--- a/torch/csrc/jit/symbolic_variable.h
+++ b/torch/csrc/jit/symbolic_variable.h
@@ -174,6 +174,10 @@
     return create(aten::type_as, {*this, rhs})[0].typeLikeWithRhsScalarType(
         *this, rhs);
   }
+  SymbolicVariable size_if_not_equal(const SymbolicVariable other) const {
+    return create(aten::_size_if_not_equal, {this->size(), other.size()})[0]
+        .toType(OptionalType::create(ListType::ofInts()));
+  }
   SymbolicVariable narrow(int dim, int64_t start, int64_t length) const {
     return create(
         t("narrow"),
@@ -306,6 +310,10 @@
       v->setType(other_type->contiguous());
     return *this;
   }
+  SymbolicVariable toType(TypePtr type) const {
+    v->setType(type);
+    return *this;
+  }
   SymbolicVariable typeLikeWithScalarType(
       SymbolicVariable other,
       at::ScalarType type) const {