[RELAND] Fix Dispatching not considering List[Optional[Tensor]] for dispatch (#68073)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68073

Relanding the original PR. Its body was as follows:

Followup to https://github.com/pytorch/pytorch/pull/60787

It turns out that the original PR was wrong for unboxed kernels. We
recently ran into this in
https://github.com/facebookresearch/functorch/issues/124

For unboxed kernels, the correct C++ type for a `Tensor?[]` argument is
actually `c10::List<c10::optional<Tensor>>`, not `ArrayRef<optional<Tensor>>`.
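Concretely, a hedged sketch using `index_put_` as the example (the exact
generated signature may differ slightly from this): the `Tensor?[]` argument
arrives at an unboxed kernel as a `c10::List`, so the dispatch key extractor
needs an overload for exactly that type.

    // schema: index_put_(Tensor(a!) self, Tensor?[] indices,
    //                    Tensor values, bool accumulate=False) -> Tensor(a!)
    at::Tensor& index_put_(
        at::Tensor& self,
        const c10::List<c10::optional<at::Tensor>>& indices,  // Tensor?[]
        const at::Tensor& values,
        bool accumulate);

Without a matching `c10::List<c10::optional<Tensor>>` overload, the keys
carried by the `indices` tensors fall through to the extractor's catch-all
and are silently dropped.
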
ghstack-source-id: 144204580

Test Plan:
- verify that the repro from https://github.com/facebookresearch/functorch/issues/124
  now works

Reviewed By: gchanan

Differential Revision: D32313601

Pulled By: zou3519

fbshipit-source-id: 8028d5f34eecabc53d603bd54d6b6748b5db461a
diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
index 31dd098..4d2e7d0 100644
--- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
+++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
@@ -66,13 +66,18 @@
         ts = ts | x.key_set();
       }
     }
-    void operator()(at::ArrayRef<c10::optional<at::Tensor>> xs) {
-      for (const auto& x : xs) {
+    // Tensor?[] translates to this case.
+    void operator()(const c10::List<c10::optional<at::Tensor>>& xs) {
+      for (c10::optional<at::Tensor> x : xs) {
         if (x.has_value()) {
           ts = ts | x.value().key_set();
         }
       }
     }
+    void operator()(at::ArrayRef<c10::optional<at::Tensor>> xs) {
+      // Tensor?[] now dispatches via the c10::List overload above; assert that this dead path stays dead.
+      TORCH_INTERNAL_ASSERT(false);
+    }
     void operator()(const at::Generator& gen) {
       if (gen.defined()) {
         ts = ts | gen.key_set();
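
For context (not part of the patch): the extractor above is an overload set
applied to each unboxed argument, with a templated catch-all that ignores
argument types carrying no dispatch keys. A minimal standalone sketch, with
toy types rather than the real ATen ones, of why the exact parameter type
matters:

    #include <cassert>
    #include <cstdint>
    #include <optional>
    #include <vector>

    using KeySet = uint64_t;
    struct Tensor { KeySet keys; };

    struct KeyExtractor {
      KeySet ts = 0;
      void operator()(const Tensor& t) { ts |= t.keys; }
      // Tensor?[] arrives from unboxed kernels as a concrete list of
      // optionals; without this exact overload the argument would fall
      // through to the catch-all and its keys would be ignored.
      void operator()(const std::vector<std::optional<Tensor>>& xs) {
        for (const auto& x : xs) {
          if (x.has_value()) ts |= x->keys;
        }
      }
      template <typename T>
      void operator()(const T&) {}  // everything else: no keys to collect
    };

    int main() {
      KeyExtractor e;
      e(Tensor{0b001});
      e(std::vector<std::optional<Tensor>>{Tensor{0b100}, std::nullopt});
      assert(e.ts == 0b101);  // keys from both the tensor and the list
    }

This mirrors the bug: before the fix, the real extractor only had the
ArrayRef overload, so a List argument silently hit the catch-all.
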
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index f40fd53..8209908 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -325,6 +325,35 @@
         self.assertEqual(y.stride(), x.stride())
         self.assertEqual(y.storage_offset(), x.storage_offset())
 
+    def test_index_put_where_only_index_is_subclass(self) -> None:
+        called_funcs = []
+
+        class MyTensor(torch.Tensor):
+            __torch_function__ = torch._C._disabled_torch_function_impl
+            elem: torch.Tensor
+            __slots__ = ['elem']
+
+            @staticmethod
+            def __new__(cls, elem, *args, **kwargs):
+                r = torch.Tensor._make_wrapper_subclass(
+                    cls, elem.size(),
+                    dtype=elem.dtype, layout=elem.layout,
+                    device=elem.device, requires_grad=elem.requires_grad
+                )
+                r.elem = elem
+                return r
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                called_funcs.append(func)
+                return MyTensor(torch.tensor(3))
+
+        x = torch.randn(3, 3)
+        idxs = (MyTensor(torch.tensor(0)),)
+        v = torch.randn(1)
+        res = x.index_put_(idxs, v)
+        self.assertEqual(called_funcs, [torch.ops.aten.index_put_])
+
     def test_enable_python_mode_error(self) -> None:
         with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
             with enable_python_mode(torch.Tensor):