Enable `out` OpInfo testing for `torch.where` (#121473)
Also fixes a behavior discrepancy between CPU and CUDA by raising an error when `out.dtype` does not match the promoted result type of `self` and `other`.
Fixes https://github.com/pytorch/pytorch/issues/121397
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121473
Approved by: https://github.com/Skylion007, https://github.com/albanD
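
A minimal sketch of the behavior this change enforces, assuming a build that includes the `TORCH_CHECK` added below (tensor values and the exact error message are illustrative):

```python
import torch

cond = torch.tensor([True, False, True])
a = torch.tensor([1.0, 2.0, 3.0])   # float32
b = torch.tensor([4.0, 5.0, 6.0])   # float32

# Matching dtype: result_type(a, b) is float32, so the out= variant succeeds.
out_ok = torch.empty(3, dtype=torch.float32)
torch.where(cond, a, b, out=out_ok)

# Mismatched dtype: with this patch, CPU and CUDA both raise the same error
# instead of diverging (the pre-patch devices disagreed, see issue #121397).
out_bad = torch.empty(3, dtype=torch.float64)
try:
    torch.where(cond, a, b, out=out_bad)
except RuntimeError as e:
    print(e)  # e.g. "Expected out type to be Float but got Double"
```
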
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index 851bdd1..30589b5 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -508,9 +508,11 @@
Tensor& where_self_out(const Tensor& condition, const Tensor& self, const Tensor& other, Tensor& out) {
+ const auto result_type = at::native::result_type(self, other);
+ TORCH_CHECK(out.scalar_type() == result_type, "Expected out type to be ", result_type, " but got ", out.scalar_type());
+
Tensor self_, other_, condition_;
if (self.dtype() != other.dtype()) {
- auto result_type = at::native::result_type(self, other);
self_ = self.to(result_type);
other_ = other.to(result_type);
} else {
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index c6c3153..4ff7266 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -18274,12 +18274,11 @@
# Currently only the `input` is tested in gradcheck.
# If we pass `condition` first, none of the input which supports
# autograd will be tested. Hence the following lambda.
- op=lambda self, condition, other: torch.where(condition, self, other),
+ op=lambda self, condition, other, **kwargs: torch.where(condition, self, other, **kwargs),
ref=lambda self, condition, other: np.where(condition, self, other),
sample_inputs_func=sample_inputs_where,
reference_inputs_func=reference_inputs_where,
error_inputs_func=error_inputs_where,
- supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=(
@@ -22516,6 +22515,7 @@
"_refs.where",
torch_opinfo_name="where",
op=lambda self, condition, other: refs.where(condition, self, other),
+ supports_out=False,
skips=(
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors', device_type='cuda'),
),
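
For context on the OpInfo side, a hypothetical illustration of why the `op` lambda now forwards `**kwargs`: the `out` OpInfo tests pass an `out=` keyword, which the previous lambda could not accept (hence the old `supports_out=False`); the `_refs.where` entry keeps `supports_out=False` because its lambda still does not forward kwargs. The names below are stand-ins, not test-suite code:

```python
import torch

# Stand-ins for the old and new OpInfo lambdas changed above.
old_op = lambda self, condition, other: torch.where(condition, self, other)
new_op = lambda self, condition, other, **kwargs: torch.where(condition, self, other, **kwargs)

cond = torch.tensor([True, False, True])
a, b = torch.ones(3), torch.zeros(3)
out = torch.empty(3)

new_op(a, cond, b, out=out)    # out= reaches torch.where, so `out` tests can exercise it
# old_op(a, cond, b, out=out)  # TypeError: <lambda>() got an unexpected keyword argument 'out'
```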