[RELAND] Fix Dispatching not considering List[Optional[Tensor]] for dispatch (#68073)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68073
Relanding the original PR. Its body was as follows:
Follow-up to https://github.com/pytorch/pytorch/pull/60787.
It turns out that the original PR was wrong for unboxed kernels; we
recently ran into this in
https://github.com/facebookresearch/functorch/issues/124
For unboxed kernels, the correct type for a `Tensor?[]` argument is
actually `c10::List<c10::optional<Tensor>>`, not
`ArrayRef<c10::optional<Tensor>>`. Because the dispatch key extractor
only had an `ArrayRef` overload, such arguments fell through to its
do-nothing catch-all overload, so the dispatch keys of the index
tensors were never considered.
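To make the unboxed convention concrete, here is a minimal C++ sketch of a
call site; it assumes `at::index` as the unboxed entry point for
`aten::index(Tensor self, Tensor?[] indices)` (the exact generated
signature comes from the codegen):

    #include <ATen/ATen.h>
    #include <ATen/core/List.h>

    int main() {
      at::Tensor self = at::randn({3, 3});

      // A Tensor?[] argument is built as a c10::List of optional tensors.
      c10::List<c10::optional<at::Tensor>> indices;
      indices.push_back(c10::optional<at::Tensor>(at::zeros({1}, at::kLong)));

      // The unboxed call: the dispatch keys of the defined index tensors
      // must be taken into account when picking a kernel.
      at::Tensor out = at::index(self, indices);
      return 0;
    }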
ghstack-source-id: 144204580
Test Plan:
- verify that the repro from
  https://github.com/facebookresearch/functorch/issues/124 now works
Reviewed By: gchanan
Differential Revision: D32313601
Pulled By: zou3519
fbshipit-source-id: 8028d5f34eecabc53d603bd54d6b6748b5db461a
diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
index 31dd098..4d2e7d0 100644
--- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
+++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
@@ -66,13 +66,18 @@
       ts = ts | x.key_set();
     }
   }
-  void operator()(at::ArrayRef<c10::optional<at::Tensor>> xs) {
-    for (const auto& x : xs) {
+  // Tensor?[] translates to this case.
+  void operator()(const c10::List<c10::optional<at::Tensor>>& xs) {
+    for (c10::optional<at::Tensor> x : xs) {
       if (x.has_value()) {
         ts = ts | x.value().key_set();
       }
     }
   }
+  void operator()(at::ArrayRef<c10::optional<at::Tensor>>) {
+    // Tensor?[] is never passed as ArrayRef; assert so we notice if that changes.
+    TORCH_INTERNAL_ASSERT(false);
+  }
   void operator()(const at::Generator& gen) {
     if (gen.defined()) {
       ts = ts | gen.key_set();
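Note the by-value loop in the new overload: `c10::List` stores `IValue`s,
so its iterator yields proxy objects rather than
`const c10::optional<Tensor>&`; binding the element as `const auto&` would
capture the proxy, which has no `has_value()` member. A standalone sketch
of the same walk (the helper name is illustrative, not part of the
codebase):

    #include <ATen/ATen.h>
    #include <ATen/core/List.h>
    #include <c10/core/DispatchKeySet.h>

    // Union the dispatch keys of the defined entries of a Tensor?[]
    // argument, mirroring the overload added above.
    c10::DispatchKeySet keys_of_defined(
        const c10::List<c10::optional<at::Tensor>>& xs) {
      c10::DispatchKeySet ts;
      // Copy each element out of the List: dereferencing the iterator
      // returns a proxy, so we materialize a real optional<Tensor>.
      for (c10::optional<at::Tensor> x : xs) {
        if (x.has_value()) {
          ts = ts | x.value().key_set();
        }
      }
      return ts;
    }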
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index f40fd53..8209908 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -325,6 +325,35 @@
         self.assertEqual(y.stride(), x.stride())
         self.assertEqual(y.storage_offset(), x.storage_offset())
 
+    def test_index_put_where_only_index_is_subclass(self) -> None:
+        called_funcs = []
+
+        class MyTensor(torch.Tensor):
+            __torch_function__ = torch._C._disabled_torch_function_impl
+            elem: torch.Tensor
+            __slots__ = ['elem']
+
+            @staticmethod
+            def __new__(cls, elem, *args, **kwargs):
+                r = torch.Tensor._make_wrapper_subclass(
+                    cls, elem.size(),
+                    dtype=elem.dtype, layout=elem.layout,
+                    device=elem.device, requires_grad=elem.requires_grad
+                )
+                r.elem = elem
+                return r
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                called_funcs.append(func)
+                return MyTensor(torch.tensor(3))
+
+        x = torch.randn(3, 3)
+        idxs = (MyTensor(torch.tensor(0)),)
+        v = torch.randn(1)
+        res = x.index_put_(idxs, v)
+        self.assertEqual(called_funcs, [torch.ops.aten.index_put_])
+
     def test_enable_python_mode_error(self) -> None:
         with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
             with enable_python_mode(torch.Tensor):