[ATen] Use MaybeOwned<T> in at::argmin/argmax (#58338)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58338
Test Plan: CI
Reviewed By: swolchok
Differential Revision: D28458968
fbshipit-source-id: 2c759bdb9fbdbef32d804f6d8efb09fb1d2bb30a
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 111a9c7..33046fc 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -1306,7 +1306,7 @@
}
Tensor& argmax_out(const Tensor& self, c10::optional<int64_t> dim, bool keepdim, Tensor& result) {
- Tensor in;
+ c10::MaybeOwned<Tensor> in;
if (dim) {
auto sizes = self.sizes();
zero_numel_check_dims(self, dim.value(), "argmax()");
@@ -1322,13 +1322,13 @@
}
return result;
}
- in = self;
+ in = c10::MaybeOwned<Tensor>::borrowed(self);
} else {
TORCH_CHECK_INDEX(self.numel() != 0, "argmax_out(): Expected reduction dim to be specified for input.numel() == 0.");
- in = self.reshape({-1});
+ in = c10::MaybeOwned<Tensor>::owned(self.reshape({-1}));
keepdim = false;
}
- auto itr = make_reduction("argmax", result, in, dim.value_or(0), keepdim,
+ auto itr = make_reduction("argmax", result, *in, dim.value_or(0), keepdim,
self.scalar_type(), at::kLong);
if (itr.numel() != 0) {
argmax_stub(itr.device_type(), itr);
@@ -1342,7 +1342,7 @@
}
Tensor& argmin_out(const Tensor& self, c10::optional<int64_t> dim, bool keepdim, Tensor& result) {
- Tensor in;
+ c10::MaybeOwned<Tensor> in;
if (dim) {
auto sizes = self.sizes();
zero_numel_check_dims(self, dim.value(), "argmin()");
@@ -1358,13 +1358,13 @@
}
return result;
}
- in = self;
+ in = c10::MaybeOwned<Tensor>::borrowed(self);
} else {
TORCH_CHECK_INDEX(self.numel() != 0, "argmin_out(): Expected reduction dim to be specified for input.numel() == 0.");
- in = self.reshape({-1});
+ in = c10::MaybeOwned<Tensor>::owned(self.reshape({-1}));
keepdim = false;
}
- auto itr = make_reduction("argmin", result, in, dim.value_or(0), keepdim,
+ auto itr = make_reduction("argmin", result, *in, dim.value_or(0), keepdim,
self.scalar_type(), at::kLong);
if (itr.numel() != 0) {
argmin_stub(itr.device_type(), itr);