update schema to reflect aliasing behavior (#39794)

Summary:
Fixes: https://github.com/pytorch/pytorch/issues/38555

I did an audit of `native_functions.yaml` and found several functions, in addition to `reshape`, that were not reporting that they could alias their input:

```
import torch

@torch.jit.script
def foo(t: torch.Tensor):
    new_value = torch.tensor(1, dtype=t.dtype, device=t.device)

    # Each of the following writes goes through an op whose output can alias `t`:
    t.flatten()[0] = new_value
    t.reshape(-1)[1] = new_value
    t.view_as(t)[2] = new_value
    t.expand_as(t)[3] = new_value
    t.reshape_as(t)[4] = new_value
    t.contiguous()[5] = new_value
    t.detach()[6] = new_value

    return t
```

Currently none of the writes survive dead code elimination; after this PR all of them do, and the JIT output matches that of eager.
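
As a quick check (a rough sketch, assuming the scripted `foo` above is in scope):

```
out = foo(torch.zeros(8))
print(out)
# Before this change the writes were optimized away and `out` stayed all zeros;
# with the corrected alias annotations, indices 0-6 are set to 1, matching eager.
```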

I don't think this needs dedicated unit tests; the generic alias-analysis machinery is presumably already tested, and this just brings these ops under the same umbrella.

**BC-breaking note**: This updates the native operator schemas and the aliasing rules for autograd. JIT passes will no longer incorrectly optimize away mutations in graphs containing these ops, and in-place ops on the result of `flatten` will now be properly tracked by Autograd so that the correct backward graph is created.
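
To illustrate the autograd side, a minimal eager-mode sketch (the tensor names and values here are just for illustration):

```
x = torch.ones(2, 2, requires_grad=True)
y = x.clone()       # non-leaf copy, so in-place mutation is allowed
v = y.flatten()     # now tracked as a view of `y` by autograd
v.add_(1)           # the in-place update is reflected in `y` and recorded
y.sum().backward()
print(x.grad)       # tensor([[1., 1.], [1., 1.]])
```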
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39794

Differential Revision: D22008358

Pulled By: robieta

fbshipit-source-id: 9d3ff536e58543211e08254a75c6110f2a3b4992
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 477d2f5..3b571b7 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -728,7 +728,7 @@
   use_c10_dispatcher: full
   variants: function
 
-- func: contiguous(Tensor self, *, MemoryFormat memory_format=contiguous_format) -> Tensor
+- func: contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)
   variants: method
 
 - func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor
@@ -1154,7 +1154,7 @@
   variants: method  # This is method-only to match the previous tensor API. In the future we could make this a function too.
   device_guard: False
 
-- func: expand_as(Tensor self, Tensor other) -> Tensor
+- func: expand_as(Tensor(a) self, Tensor other) -> Tensor(a)
   use_c10_dispatcher: full
   variants: method  # This is method-only to match the previous tensor API. In the future we could make this a function too.
   device_guard: False
@@ -1173,17 +1173,17 @@
     CPU: eye_out_cpu
     CUDA: eye_out_cuda
 
-- func: flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor
+- func: flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)
   use_c10_dispatcher: full
   variants: function, method
 
-- func: flatten.named_out_dim(Tensor self, int start_dim, int end_dim, Dimname out_dim) -> Tensor
+- func: flatten.named_out_dim(Tensor(a) self, int start_dim, int end_dim, Dimname out_dim) -> Tensor(a)
   variants: function, method
 
-- func: flatten.using_names(Tensor self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor
+- func: flatten.using_names(Tensor(a) self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor(a)
   variants: function, method
 
-- func: flatten.DimnameList(Tensor self, Dimname[] dims, Dimname out_dim) -> Tensor
+- func: flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a)
   variants: function, method
 
 - func: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
@@ -2076,7 +2076,7 @@
   use_c10_dispatcher: full
   variants: method
 
-- func: pin_memory(Tensor self) -> Tensor
+- func: pin_memory(Tensor(a) self) -> Tensor(a)
   use_c10_dispatcher: full
   variants: method
 
@@ -2225,7 +2225,7 @@
   use_c10_dispatcher: full
   variants: function, method
 
-- func: reshape(Tensor self, int[] shape) -> Tensor
+- func: reshape(Tensor(a) self, int[] shape) -> Tensor(a)
   use_c10_dispatcher: full
   variants: function, method
   device_guard: False
@@ -2236,7 +2236,7 @@
   dispatch:
     MkldnnCPU: mkldnn_reshape
 
-- func: reshape_as(Tensor self, Tensor other) -> Tensor
+- func: reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)
   use_c10_dispatcher: full
   variants: method
   device_guard: False
@@ -2385,7 +2385,7 @@
 # to false to make such changes explicitly illegal, in order to prevent users from
 # changing metadata of the detached tensor and expecting the original tensor to also
 # be updated.
-- func: detach(Tensor self) -> Tensor
+- func: detach(Tensor(a) self) -> Tensor(a)
   use_c10_dispatcher: full
   manual_kernel_registration: True
   variants: function, method
@@ -2837,7 +2837,7 @@
 - func: var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
   variants: function
 
-- func: view_as(Tensor self, Tensor other) -> Tensor
+- func: view_as(Tensor(a) self, Tensor other) -> Tensor(a)
   use_c10_dispatcher: full
   variants: method
   device_guard: False
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index 602f00b..206e0b0 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -99,6 +99,14 @@
     ('aten::pow', datetime.date(2020, 6, 30)),
     ('prim::min', datetime.date(2020, 6, 30)),
     ('prim::max', datetime.date(2020, 6, 30)),
+    ('aten::view_as', datetime.date(2020, 6, 30)),
+    ('aten::reshape_as', datetime.date(2020, 6, 30)),
+    ('aten::pin_memory', datetime.date(2020, 6, 30)),
+    ('aten::reshape', datetime.date(2020, 6, 30)),
+    ('aten::detach', datetime.date(2020, 6, 30)),
+    ('aten::expand_as', datetime.date(2020, 6, 30)),
+    ('aten::flatten.*', datetime.date(2020, 6, 30)),
+    ('aten::contiguous', datetime.date(2020, 6, 30)),
     ('aten::to_here', datetime.date(2020, 6, 30)),
     ('aten::to_here(RRef(t) self, double timeout*)', datetime.date(2020, 6, 30)),
     ('aten::local_value', datetime.date(2020, 6, 30)),
diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py
index 645e914..b5bd05c 100644
--- a/tools/autograd/gen_autograd.py
+++ b/tools/autograd/gen_autograd.py
@@ -54,6 +54,7 @@
     'transpose': 'self',
     'unfold': 'self',
     'unsqueeze': 'self',
+    'flatten': 'self',
     'view': 'self',
     'unbind': 'self',
     '_indices': 'self',
@@ -73,8 +74,12 @@
 # note: some VIEW_FUNCTIONS are just compositions of the view functions above
 # this list contains both the root view functions and any that are purely composed
 # of viewing functions, and is used by the JIT to determine when an operator
-# returns a view of its inputs
-RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union({'chunk', 'split', 'real', 'imag'})
+# may return a view of its inputs; however they may sometimes return a copy.
+# (e.g. `contiguous`)
+RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union({
+    'chunk', 'split', 'detach', 'contiguous', 'reshape', 'reshape_as',
+    'expand_as', 'view_as', 'real', 'imag',
+})
 
 def format_return_type(returns):
     if len(returns) == 0:
diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp
index 794689a..ccebbd7 100644
--- a/torch/csrc/autograd/VariableTypeManual.cpp
+++ b/torch/csrc/autograd/VariableTypeManual.cpp
@@ -367,7 +367,7 @@
     .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)
     .impl_unboxedOnlyKernel<decltype(VariableType::resize_as_), &VariableType::resize_as_>(DispatchKey::Autograd))
   .op(torch::RegisterOperators::options()
-    .schema("aten::detach(Tensor self) -> Tensor")
+    .schema("aten::detach(Tensor(a) self) -> Tensor(a)")
     .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)
     .kernel<decltype(VariableType::detach), &VariableType::detach>(DispatchKey::Autograd))
   .op(torch::RegisterOperators::options()
diff --git a/torch/csrc/jit/passes/requires_grad_analysis.cpp b/torch/csrc/jit/passes/requires_grad_analysis.cpp
index a8afbe6..ab41d9e 100644
--- a/torch/csrc/jit/passes/requires_grad_analysis.cpp
+++ b/torch/csrc/jit/passes/requires_grad_analysis.cpp
@@ -64,7 +64,7 @@
   } else if (node->matches(
                  "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
     return setRequiresGrad(node->output(), node->input(0)->requires_grad());
-  } else if (node->matches("aten::detach(Tensor self) -> Tensor")) {
+  } else if (node->matches("aten::detach(Tensor(a) self) -> Tensor(a)")) {
     return setRequiresGrad(node->output(), false);
   } else if (node->kind() == aten::tensor) {
     if (auto grad_index =
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index b624f97..a1f6a10 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -803,7 +803,7 @@
             "aten::atan(Tensor self) -> Tensor",
             "aten::ceil(Tensor self) -> Tensor",
             "aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
-            "aten::contiguous(Tensor self, *, MemoryFormat memory_format=contiguous_format) -> Tensor",
+            "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)",
             "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
             "aten::celu(Tensor self, Scalar alpha) -> Tensor",
             "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
@@ -841,7 +841,7 @@
             "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
             "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
             "aten::permute(Tensor self, int[] dims) -> Tensor",
-            "aten::pin_memory(Tensor self) -> Tensor",
+            "aten::pin_memory(Tensor(a) self) -> Tensor(a)",
             "aten::pinverse(Tensor self, float rcond) -> Tensor",
             "aten::reciprocal(Tensor self) -> Tensor",
             "aten::relu(Tensor self) -> Tensor",
@@ -1577,7 +1577,7 @@
         node->output()->setType(type->withDim(1));
         return true;
       }
-    } else if (node->matches("aten::detach(Tensor self) -> Tensor")) {
+    } else if (node->matches("aten::detach(Tensor(a) self) -> Tensor(a)")) {
       if (auto type = input_type(0)) {
         node->output()->setType(type->withRequiresGrad(false));
         return true;
@@ -1704,11 +1704,12 @@
         return tensor_types.at(0)->withScalarType(
             tensor_types.at(1)->scalarType());
       } else if (
-          node->matches("aten::view_as(Tensor self, Tensor other) -> Tensor") ||
           node->matches(
-              "aten::expand_as(Tensor self, Tensor other) -> Tensor") ||
+              "aten::view_as(Tensor(a) self, Tensor other) -> Tensor(a)") ||
           node->matches(
-              "aten::reshape_as(Tensor self, Tensor other) -> Tensor")) {
+              "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)") ||
+          node->matches(
+              "aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)")) {
         return tensor_types.at(0)->withDim(tensor_types.at(1)->dim());
       } else if (
           node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
@@ -1753,8 +1754,9 @@
           }
         }
         return nullptr;
-      } else if (node->matches(
-                     "aten::reshape(Tensor self, int[] shape) -> Tensor")) {
+      } else if (
+          node->matches(
+              "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)")) {
         return reshape_prop(node, attr::shape, tensor_types);
       } else if (node->matches(
                      "aten::repeat(Tensor self, int[] repeats) -> Tensor")) {