Enable `out` OpInfo testing for `torch.where` (#121473)

And fix a behavior discrepancy between CPU and CUDA by raising an error when `out.dtype` does not match the promoted result type of `self` and `other`

Fixes https://github.com/pytorch/pytorch/issues/121397
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121473
Approved by: https://github.com/Skylion007, https://github.com/albanD
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index 851bdd1..30589b5 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -508,9 +508,11 @@
 
 
 Tensor& where_self_out(const Tensor& condition, const Tensor& self, const Tensor& other, Tensor& out) {
+  const auto result_type = at::native::result_type(self, other);
+  TORCH_CHECK(out.scalar_type() == result_type, "Expected out type to be ", result_type, " but got ", out.scalar_type());
+
   Tensor self_, other_, condition_;
   if (self.dtype() != other.dtype()) {
-    auto result_type = at::native::result_type(self, other);
     self_ = self.to(result_type);
     other_ = other.to(result_type);
   } else {
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index c6c3153..4ff7266 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -18274,12 +18274,11 @@
            # Currently only the `input` is tested in gradcheck.
            # If we pass `condition` first, none of the input which supports
            # autograd will be tested. Hence the following lambda.
-           op=lambda self, condition, other: torch.where(condition, self, other),
+           op=lambda self, condition, other, **kwargs: torch.where(condition, self, other, **kwargs),
            ref=lambda self, condition, other: np.where(condition, self, other),
            sample_inputs_func=sample_inputs_where,
            reference_inputs_func=reference_inputs_where,
            error_inputs_func=error_inputs_where,
-           supports_out=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            decorators=(
@@ -22516,6 +22515,7 @@
         "_refs.where",
         torch_opinfo_name="where",
         op=lambda self, condition, other: refs.where(condition, self, other),
+        supports_out=False,
         skips=(
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors', device_type='cuda'),
         ),