Port `all` and `any` full reductions to structured kernels. (#64642)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64642

Tracking issue: #55070

This PR adds `out` overloads for the full-reduction variants of `all` and `any`
(i.e. `all.all_out` and `any.all_out`), and ports both kernels to structured kernels.
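
For reference, a minimal sketch of how the new full-reduction `out=` overloads behave from the Python API (the tensor values and the `out` buffer below are illustrative, not taken from this PR):

```python
import torch

x = torch.tensor([[True, False], [True, True]])

# Full reductions return a 0-dim bool tensor.
print(torch.all(x))   # tensor(False)
print(torch.any(x))   # tensor(True)

# The new out= overloads (all.all_out / any.all_out) write into a
# preallocated result tensor instead of allocating a new one.
out = torch.empty((), dtype=torch.bool)
torch.all(x, out=out)
print(out)            # tensor(False)

# Per Note [all, any : uint8 compatibility], uint8 inputs keep their dtype.
u8 = torch.ones(3, dtype=torch.uint8)
print(torch.all(u8).dtype)  # torch.uint8
```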

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D30867354

Pulled By: ezyang

fbshipit-source-id: 46bccaf6c94a09ed77cc6c724d1183c82f801751
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 9d3ef2b..3674a35 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -92,13 +92,7 @@
   }
 }
 
-void check_all_any(const char* name, const Tensor& self, const Tensor& result) {
-  // Refer [all, any : uint8 compatibility]
-  TORCH_CHECK(
-      self.layout() == Layout::Strided,
-      name, " only supports strided layout, got: ",
-      self.layout());
-
+void check_result_is_bytebool(const char* name, const Tensor& self, const Tensor& result) {
   if (result.defined()) {
     // Refer [all, any : uint8 compatibility]
     TORCH_CHECK(
@@ -109,20 +103,42 @@
   }
 }
 
+// Note [all, any : uint8 compatibility]:
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// For NumPy compatibility, `all` and `any` return
+// a Tensor of dtype `bool`. However, for compatibility reasons,
+// for `uint8` they return a Tensor of the same dtype `uint8`.
+// Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
+static void allany_meta(
+    impl::MetaBase& meta,
+    const char* name,
+    const Tensor& self,
+    IntArrayRef dims,
+    bool keepdim) {
+  const auto& result = meta.maybe_get_output();
+  check_result_is_bytebool(name, self, result);
+  auto out_dtype = get_result_or_bytebool_dtype(self, result);
+  resize_reduction(meta, self, dims, keepdim, out_dtype);
+}
+
 TORCH_PRECOMPUTE_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
-  check_all_any("all", self, maybe_get_output());
-  auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
-  resize_reduction(*this, self, dim, keepdim, out_dtype);
+  allany_meta(*this, "all", self, dim, keepdim);
   return TORCH_PRECOMPUTE_STRUCT2(all, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
 }
 
+TORCH_META_FUNC(all)(const Tensor& self) {
+  allany_meta(*this, "all", self, {}, false);
+}
+
 TORCH_PRECOMPUTE_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
-  check_all_any("any", self, maybe_get_output());
-  auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
-  resize_reduction(*this, self, dim, keepdim, out_dtype);
+  allany_meta(*this, "any", self, dim, keepdim);
   return TORCH_PRECOMPUTE_STRUCT2(any, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
 }
 
+TORCH_META_FUNC(any)(const Tensor& self) {
+  allany_meta(*this, "any", self, {}, false);
+}
+
 void check_argmax_argmin(
     const char* name,
     const Tensor& self,
@@ -1323,22 +1339,6 @@
   return at::norm(self, p, IntArrayRef{}, false);
 }
 
-// Note [all, any : uint8 compatibility]:
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// For NumPy compatibility, `all` and `any` return
-// a Tensor of dtype `bool`. However, for compatibility reasons,
-// for `uint8` they return a Tensor of the same dtype `uint8`.
-// Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
-inline const Tensor & _all(const Tensor & result, TensorIterator & iter) {
-  if (iter.numel() == 0) {
-    result.fill_(1);
-  } else {
-    and_stub(iter.device_type(), iter);
-  }
-
-  return result;
-}
-
 inline TensorIterator get_allany_iter(
     const Tensor& self,
     const Tensor& result,
@@ -1355,61 +1355,39 @@
       self, result, dims, keepdim, result.scalar_type());
 }
 
-Tensor all(const Tensor& self) {
-  Tensor result;
-
-  meta::check_all_any("all", self, result);
-  auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
-  auto shape = meta::get_reduction_shape(self, {}, false);
-
-  result = at::empty(shape, self.options().dtype(out_dtype));
-  auto iter = get_allany_iter(self, result, {}, false);
-
-  return _all(result, iter);
+template <int identity, typename Stub>
+inline void allany_impl(
+    const Tensor& self,
+    const Tensor& result,
+    IntArrayRef dims,
+    bool keepdim,
+    Stub& stub) {
+  if (self.numel() == 0) {
+    result.fill_(identity);
+  } else if (self.numel() == 1) {
+    result.fill_(self.item().toBool());
+  } else {
+    auto iter = get_allany_iter(self, result, dims, keepdim);
+    stub(iter.device_type(), iter);
+  }
 }
 
 TORCH_IMPL_FUNC(all_out)
 (const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
-  auto iter = get_allany_iter(self, result, dim, keepdim);
-  auto mut_result = const_cast<Tensor&>(result);
-  if (!_dimreduce_return_trivial(mut_result, self, 1, dim, keepdim)) {
-    _all(mut_result, iter);
-  }
+  allany_impl<1>(self, result, dim, keepdim, and_stub);
 }
 
-inline const Tensor & _any(const Tensor & result, TensorIterator & iter) {
-  if (iter.numel() == 0) {
-    result.fill_(0);
-  } else {
-    or_stub(iter.device_type(), iter);
-  }
-
-  return result;
-}
-
-Tensor any(const Tensor& self) {
-  Tensor result;
-
-  meta::check_all_any("any", self, result);
-  auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
-  auto shape = meta::get_reduction_shape(self, {}, false);
-
-  result = at::empty(shape, self.options().dtype(out_dtype));
-  auto iter = get_allany_iter(self, result, {}, false);
-
-  return _any(result, iter);
+TORCH_IMPL_FUNC(all_all_out)(const Tensor& self, const Tensor& result) {
+  allany_impl<1>(self, result, {}, false, and_stub);
 }
 
 TORCH_IMPL_FUNC(any_out)
-(const Tensor& self,
- int64_t dim,
- bool keepdim,
- const Tensor& result) {
-  auto iter = get_allany_iter(self, result, dim, keepdim);
-  auto mut_result = const_cast<Tensor&>(result);
-  if (!_dimreduce_return_trivial(mut_result, self, 0, dim, keepdim)) {
-    _any(mut_result, iter);
-  }
+(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
+  allany_impl<0>(self, result, dim, keepdim, or_stub);
+}
+
+TORCH_IMPL_FUNC(any_all_out)(const Tensor& self, const Tensor& result) {
+  allany_impl<0>(self, result, {}, false, or_stub);
 }
 
 Tensor &amin_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tensor& result) {
diff --git a/aten/src/ATen/native/ReduceOpsUtils.h b/aten/src/ATen/native/ReduceOpsUtils.h
index e177718..a357174 100644
--- a/aten/src/ATen/native/ReduceOpsUtils.h
+++ b/aten/src/ATen/native/ReduceOpsUtils.h
@@ -51,17 +51,16 @@
   return src.as_strided(replacement_shape, strides);
 }
 
-inline Tensor &_dimreduce_setup(Tensor &result, const Tensor &self,
+inline void _dimreduce_setup(const Tensor &result, const Tensor &self,
                                 int64_t dim) {
   IntArrayRef self_sizes = self.sizes();
   std::vector<int64_t> result_sizes;
   result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
   result_sizes[dim] = 1;
   result.resize_(result_sizes);
-  return result;
 }
 
-inline bool _dimreduce_return_trivial(Tensor &result, const Tensor &self,
+inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self,
                                       const Scalar& ident, int64_t dim, bool keepdim) {
   if (self.numel() == 1 && self.ndimension() == 0) {
     result.resize_({});
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 16c4b06..edcccaa 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -7375,17 +7375,28 @@
 
 - func: all(Tensor self) -> Tensor
   device_check: NoCheck   # TensorIterator
+  structured_delegate: all.all_out
   variants: method, function
+
+- func: all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  structured: True
   dispatch:
-    CPU, CUDA: all
+    CPU, CUDA: all_all_out
 
 - func: any(Tensor self) -> Tensor
   device_check: NoCheck   # TensorIterator
+  structured_delegate: any.all_out
   variants: method, function
   dispatch:
-    CPU, CUDA: any
     SparseCPU, SparseCUDA: any_sparse
 
+- func: any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  structured: True
+  dispatch:
+    CPU, CUDA: any_all_out
+
 - func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck   # TensorIterator
   structured: True