Port `all` and `any` full reductions to structured kernels. (#64642)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64642
Tracking issue: #55070
This PR adds `out` overloads for the full reductions of both the `all` and `any` kernels,
and ports them to structured kernels.
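A minimal usage sketch of the user-visible effect (assuming the generated Python
bindings expose the new `all.all_out` / `any.all_out` schemas through the usual
`out=` keyword, as they already do for the existing dim overloads):

    import torch

    x = torch.tensor([[True, False], [True, True]])
    out = torch.empty((), dtype=torch.bool)

    # The full reductions can now write into a preallocated result tensor.
    torch.all(x, out=out)  # out becomes tensor(False)
    torch.any(x, out=out)  # out becomes tensor(True)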
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D30867354
Pulled By: ezyang
fbshipit-source-id: 46bccaf6c94a09ed77cc6c724d1183c82f801751
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 9d3ef2b..3674a35 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -92,13 +92,7 @@
}
}
-void check_all_any(const char* name, const Tensor& self, const Tensor& result) {
- // Refer [all, any : uint8 compatibility]
- TORCH_CHECK(
- self.layout() == Layout::Strided,
- name, " only supports strided layout, got: ",
- self.layout());
-
+void check_result_is_bytebool(const char* name, const Tensor& self, const Tensor& result) {
if (result.defined()) {
// Refer [all, any : uint8 compatibility]
TORCH_CHECK(
@@ -109,20 +103,42 @@
}
}
+// Note [all, any : uint8 compatibility]:
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// For NumPy compatibility, `all` and `any` return
+// a Tensor of dtype `bool`. However, for backward compatibility,
+// `uint8` inputs return a Tensor of the same dtype `uint8`.
+// Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
+static void allany_meta(
+ impl::MetaBase& meta,
+ const char* name,
+ const Tensor& self,
+ IntArrayRef dims,
+ bool keepdim) {
+ const auto& result = meta.maybe_get_output();
+ check_result_is_bytebool(name, self, result);
+ auto out_dtype = get_result_or_bytebool_dtype(self, result);
+ resize_reduction(meta, self, dims, keepdim, out_dtype);
+}
+
TORCH_PRECOMPUTE_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
- check_all_any("all", self, maybe_get_output());
- auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
- resize_reduction(*this, self, dim, keepdim, out_dtype);
+ allany_meta(*this, "all", self, dim, keepdim);
return TORCH_PRECOMPUTE_STRUCT2(all, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
}
+TORCH_META_FUNC(all)(const Tensor& self) {
+ allany_meta(*this, "all", self, {}, false);
+}
+
TORCH_PRECOMPUTE_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
- check_all_any("any", self, maybe_get_output());
- auto out_dtype = get_result_or_bytebool_dtype(self, maybe_get_output());
- resize_reduction(*this, self, dim, keepdim, out_dtype);
+ allany_meta(*this, "any", self, dim, keepdim);
return TORCH_PRECOMPUTE_STRUCT2(any, dim)().set_dim(maybe_wrap_dim(dim, self.dim()));
}
+TORCH_META_FUNC(any)(const Tensor& self) {
+ allany_meta(*this, "any", self, {}, false);
+}
+
void check_argmax_argmin(
const char* name,
const Tensor& self,
@@ -1323,22 +1339,6 @@
return at::norm(self, p, IntArrayRef{}, false);
}
-// Note [all, any : uint8 compatibility]:
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// For NumPy comptability, `all` and `any` return
-// Tensor of dtype `bool`. However for compatibility reason,
-// for `uint8`, they return Tensor of same dtype `uint8`.
-// Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
-inline const Tensor & _all(const Tensor & result, TensorIterator & iter) {
- if (iter.numel() == 0) {
- result.fill_(1);
- } else {
- and_stub(iter.device_type(), iter);
- }
-
- return result;
-}
-
inline TensorIterator get_allany_iter(
const Tensor& self,
const Tensor& result,
@@ -1355,61 +1355,39 @@
self, result, dims, keepdim, result.scalar_type());
}
-Tensor all(const Tensor& self) {
- Tensor result;
-
- meta::check_all_any("all", self, result);
- auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
- auto shape = meta::get_reduction_shape(self, {}, false);
-
- result = at::empty(shape, self.options().dtype(out_dtype));
- auto iter = get_allany_iter(self, result, {}, false);
-
- return _all(result, iter);
+template <int identity, typename Stub>
+inline void allany_impl(
+ const Tensor& self,
+ const Tensor& result,
+ IntArrayRef dims,
+ bool keepdim,
+ Stub& stub) {
+ if (self.numel() == 0) {
+ result.fill_(identity);
+ } else if (self.numel() == 1) {
+ result.fill_(self.item().toBool());
+ } else {
+ auto iter = get_allany_iter(self, result, dims, keepdim);
+ stub(iter.device_type(), iter);
+ }
}
TORCH_IMPL_FUNC(all_out)
(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
- auto iter = get_allany_iter(self, result, dim, keepdim);
- auto mut_result = const_cast<Tensor&>(result);
- if (!_dimreduce_return_trivial(mut_result, self, 1, dim, keepdim)) {
- _all(mut_result, iter);
- }
+ allany_impl<1>(self, result, dim, keepdim, and_stub);
}
-inline const Tensor & _any(const Tensor & result, TensorIterator & iter) {
- if (iter.numel() == 0) {
- result.fill_(0);
- } else {
- or_stub(iter.device_type(), iter);
- }
-
- return result;
-}
-
-Tensor any(const Tensor& self) {
- Tensor result;
-
- meta::check_all_any("any", self, result);
- auto out_dtype = meta::get_result_or_bytebool_dtype(self, result);
- auto shape = meta::get_reduction_shape(self, {}, false);
-
- result = at::empty(shape, self.options().dtype(out_dtype));
- auto iter = get_allany_iter(self, result, {}, false);
-
- return _any(result, iter);
+TORCH_IMPL_FUNC(all_all_out)(const Tensor& self, const Tensor& result) {
+ allany_impl<1>(self, result, {}, false, and_stub);
}
TORCH_IMPL_FUNC(any_out)
-(const Tensor& self,
- int64_t dim,
- bool keepdim,
- const Tensor& result) {
- auto iter = get_allany_iter(self, result, dim, keepdim);
- auto mut_result = const_cast<Tensor&>(result);
- if (!_dimreduce_return_trivial(mut_result, self, 0, dim, keepdim)) {
- _any(mut_result, iter);
- }
+(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
+ allany_impl<0>(self, result, dim, keepdim, or_stub);
+}
+
+TORCH_IMPL_FUNC(any_all_out)(const Tensor& self, const Tensor& result) {
+ allany_impl<0>(self, result, {}, false, or_stub);
}
Tensor &amin_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tensor& result) {
diff --git a/aten/src/ATen/native/ReduceOpsUtils.h b/aten/src/ATen/native/ReduceOpsUtils.h
index e177718..a357174 100644
--- a/aten/src/ATen/native/ReduceOpsUtils.h
+++ b/aten/src/ATen/native/ReduceOpsUtils.h
@@ -51,17 +51,16 @@
return src.as_strided(replacement_shape, strides);
}
-inline Tensor &_dimreduce_setup(Tensor &result, const Tensor &self,
+inline void _dimreduce_setup(const Tensor &result, const Tensor &self,
int64_t dim) {
IntArrayRef self_sizes = self.sizes();
std::vector<int64_t> result_sizes;
result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
result_sizes[dim] = 1;
result.resize_(result_sizes);
- return result;
}
-inline bool _dimreduce_return_trivial(Tensor &result, const Tensor &self,
+inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self,
const Scalar& ident, int64_t dim, bool keepdim) {
if (self.numel() == 1 && self.ndimension() == 0) {
result.resize_({});
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 16c4b06..edcccaa 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -7375,17 +7375,28 @@
- func: all(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
+ structured_delegate: all.all_out
variants: method, function
+
+- func: all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+ device_check: NoCheck
+ structured: True
dispatch:
- CPU, CUDA: all
+ CPU, CUDA: all_all_out
- func: any(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
+ structured_delegate: any.all_out
variants: method, function
dispatch:
- CPU, CUDA: any
SparseCPU, SparseCUDA: any_sparse
+- func: any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+ device_check: NoCheck
+ structured: True
+ dispatch:
+ CPU, CUDA: any_all_out
+
- func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True