[ATen] Use MaybeOwned<T> in at::argmin/argmax (#58338)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/58338

Test Plan: CI

Reviewed By: swolchok

Differential Revision: D28458968

fbshipit-source-id: 2c759bdb9fbdbef32d804f6d8efb09fb1d2bb30a
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 111a9c7..33046fc 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -1306,7 +1306,7 @@
 }
 
 Tensor& argmax_out(const Tensor& self, c10::optional<int64_t> dim, bool keepdim, Tensor& result) {
-  Tensor in;
+  c10::MaybeOwned<Tensor> in;
   if (dim) {
     auto sizes = self.sizes();
     zero_numel_check_dims(self, dim.value(), "argmax()");
@@ -1322,13 +1322,13 @@
       }
       return result;
     }
-    in = self;
+    in = c10::MaybeOwned<Tensor>::borrowed(self);
   } else {
     TORCH_CHECK_INDEX(self.numel() != 0, "argmax_out(): Expected reduction dim to be specified for input.numel() == 0.");
-    in = self.reshape({-1});
+    in = c10::MaybeOwned<Tensor>::owned(self.reshape({-1}));
     keepdim = false;
   }
-  auto itr = make_reduction("argmax", result, in, dim.value_or(0), keepdim,
+  auto itr = make_reduction("argmax", result, *in, dim.value_or(0), keepdim,
       self.scalar_type(), at::kLong);
   if (itr.numel() != 0) {
     argmax_stub(itr.device_type(), itr);
@@ -1342,7 +1342,7 @@
 }
 
 Tensor& argmin_out(const Tensor& self, c10::optional<int64_t> dim, bool keepdim, Tensor& result) {
-  Tensor in;
+  c10::MaybeOwned<Tensor> in;
   if (dim) {
     auto sizes = self.sizes();
     zero_numel_check_dims(self, dim.value(), "argmin()");
@@ -1358,13 +1358,13 @@
       }
       return result;
     }
-    in = self;
+    in = c10::MaybeOwned<Tensor>::borrowed(self);
   } else {
     TORCH_CHECK_INDEX(self.numel() != 0, "argmin_out(): Expected reduction dim to be specified for input.numel() == 0.");
-    in = self.reshape({-1});
+    in = c10::MaybeOwned<Tensor>::owned(self.reshape({-1}));
     keepdim = false;
   }
-  auto itr = make_reduction("argmin", result, in, dim.value_or(0), keepdim,
+  auto itr = make_reduction("argmin", result, *in, dim.value_or(0), keepdim,
       self.scalar_type(), at::kLong);
   if (itr.numel() != 0) {
     argmin_stub(itr.device_type(), itr);