Back out "Revert D34524207: [pytorch][PR] remove _s_where" (#73579)

Summary:
Re-lands D34524207 (remove `_s_where`): `at::where` now builds its TensorIterator directly and the internal `aten::_s_where` op is deleted, with the cross-device error path covered by an ErrorInputs entry instead of a dedicated test.

Original commit changeset: 87b1220d851c

Original Phabricator Diff: D34524207 (https://github.com/pytorch/pytorch/commit/4eb248256801103b08726bf5d85496641cebcdbb)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73579

Test Plan:
OSS tests
Tested with canary: https://www.internalfb.com/intern/ads/canary/441912928798660873

Reviewed By: ezyang

Differential Revision: D34688237

Pulled By: ngimel

fbshipit-source-id: 32f3a0046053ef52e95ab45a26bfc1de17e7e061
(cherry picked from commit d1c0acbe3e0ff884c429072923a468ee1d3d447d)
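
For reference, a minimal sketch (not part of the patch) of the user-facing behavior this re-land keeps, assuming a PyTorch build that contains this commit:

    import torch

    cond = torch.tensor([True, False, True])
    x = torch.tensor([1.0, 2.0, 3.0])
    y = torch.tensor([10.0, 20.0, 30.0])

    # where() still broadcasts and selects elementwise; the kernel now runs
    # through a TensorIterator built inside at::where instead of dispatching
    # to the removed aten::_s_where.
    print(torch.where(cond, x, y))                 # tensor([ 1., 20.,  3.])

    # A uint8 condition is still accepted but warns (once) that it is deprecated.
    print(torch.where(cond.to(torch.uint8), x, y))

    # Mismatched x/y dtypes raise "expected scalar type ... but found ...";
    # tensors on different devices now hit TensorIterator's generic
    # "Expected all tensors to be on the same device" error, which the new
    # error_inputs_where OpInfo entry exercises on CUDA.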
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index 0114deb..f5cff20 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -324,20 +324,25 @@
 }
 
 Tensor where(const Tensor& condition, const Tensor& self, const Tensor& other) {
-  TORCH_CHECK(condition.device() == self.device() && self.device() == other.device(),
-              "Expected condition, x and y to be on the same device, but condition is on ",
-              condition.device(), " and x and y are on ", self.device(), " and ", other.device(),
-              " respectively");
+  TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
 
   if (condition.scalar_type() == ScalarType::Byte) {
   TORCH_WARN_ONCE("where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.");
-} else {
+  } else {
   TORCH_CHECK(condition.scalar_type() == ScalarType::Bool, "where expected condition to be a boolean tensor, but got a tensor with dtype ", condition.scalar_type());
-}
+  }
+  Tensor cond_bool = condition.scalar_type() == ScalarType::Byte ? condition.to(ScalarType::Bool) : condition;
+  Tensor ret = at::empty({0}, self.options());
+  auto iter = at::TensorIteratorConfig()
+    .check_all_same_dtype(false)
+    .add_output(ret)
+    .add_input(cond_bool)
+    .add_input(self)
+    .add_input(other)
+    .build();
+  where_kernel(iter.device_type(), iter);
+  return ret;
 
-  c10::MaybeOwned<Tensor> b_condition, b_self, b_other;
-  std::tie(b_condition, b_self, b_other) = expand_outplace(condition, self, other, "where");
-  return at::_s_where(*b_condition, *b_self, *b_other);
 }
 
 Tensor where(const Tensor& condition, const Scalar& self, const Tensor& other) {
@@ -359,22 +364,6 @@
   return condition.nonzero_numpy();
 }
 
-Tensor _s_where(const Tensor& condition, const Tensor& self, const Tensor& other) {
-  TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
-  Tensor ret = at::empty(self.sizes(), self.options());
-  //
-  Tensor cond_bool = condition.scalar_type() == ScalarType::Byte ? condition.to(ScalarType::Bool) : condition;
-  auto iter = at::TensorIteratorConfig()
-    .check_all_same_dtype(false)
-    .add_output(ret)
-    .add_input(cond_bool)
-    .add_input(self)
-    .add_input(other)
-    .build();
-  where_kernel(iter.device_type(), iter);
-  return ret;
-}
-
 std::tuple<Tensor, Tensor> mode(const Tensor& self, int64_t dim, bool keepdim) {
   Tensor values = at::empty({0}, self.options());
   Tensor indices = at::empty({0}, self.options().dtype(kLong));
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index baad10f..c408ea9 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4772,12 +4772,11 @@
   device_check: NoCheck
   device_guard: False
 
-# we define both of these because 'where' does the broadcast and '_s_where' doesn't;
-# this allows us to implicitly calculate the broadcast derivative, while only dealing with the
-# _s_where derivative.
 - func: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
   device_check: NoCheck   # TensorIterator
   variants: function, method
+  dispatch:
+    CPU, CUDA: where
 
 - func: where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor
   variants: function
@@ -4792,11 +4791,6 @@
   device_check: NoCheck   # TensorIterator
   variants: function
 
-- func: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
-  variants: function
-  dispatch:
-    CPU, CUDA: _s_where
-
 - func: norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor
   variants: function
 
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index b7dc0d5..0860bfc 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -110,6 +110,7 @@
     ("aten::grid_sampler_3d_backward", datetime.date(9999, 1, 1)),
     ("aten::_transform_bias_rescale_qkv", datetime.date(9999, 1, 1)),
     ("aten::_scatter_reduce.two", datetime.date(9999, 1, 1)),
+    ("aten::_s_where", datetime.date(2022, 9, 30)),
 ]
 
 ALLOW_LIST_COMPILED = [
diff --git a/test/test_torch.py b/test/test_torch.py
index 73ef010..24af923 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -116,19 +116,6 @@
 class TestTorchDeviceType(TestCase):
     exact_dtype = True
 
-    # FIXME: Port this to ErrorInputs on where
-    @onlyCUDA
-    @dtypes(torch.float32)
-    def test_where_invalid_device(self, device, dtype):
-        for devices in [('cpu', device, device), (device, 'cpu', 'cpu'),
-                        (device, 'cpu', device), ('cpu', device, 'cpu')]:
-            condition = make_tensor(16, device=devices[0], dtype=torch.float32)
-            x = make_tensor(16, device=devices[1], dtype=torch.float32)
-            y = make_tensor(16, device=devices[2], dtype=torch.float32)
-            with self.assertRaisesRegex(RuntimeError,
-                                        "Expected condition, x and y to be on the same device"):
-                torch.where(condition, x, y)
-
     # TODO: move all tensor creation to common ops
     def _rand_shape(self, dim, min_size, max_size):
         shape = []
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index fd0b773..54aac6e 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1643,7 +1643,7 @@
   self: at::view_as_real(grad.contiguous().resolve_conj()) # [gx, gy]
   result: at::view_as_complex(self_t)
 
-- name: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor
+- name: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
   condition: non_differentiable
   self: where(condition, grad, zeros_like(grad))
   other: where(condition, zeros_like(grad), grad)
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 4b63414..7f4c44f 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -91,7 +91,7 @@
     'triu', 'chunk', 'zero_', 'eq_', 'ne_', 'add', '__radd__', 'sum',
     '_conj', 'sin', 'cos', 'mul', 'sinc', 'sinh', 'cosh', '__rmul__',
     'sgn', 'asin', 'acos', 'sub', 'div', 'cat', 'view_as_complex', 'index_put',
-    'neg', 'complex', 'select', '_s_where', 'as_strided', 'slice', 'constant_pad_nd',
+    'neg', 'complex', 'select', 'where', 'as_strided', 'slice', 'constant_pad_nd',
     'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
     'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'outer',
     'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 67e2aa4..f1dc434 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7096,6 +7096,16 @@
                           args=(make_bool_mask(mask_shape), make_arg(other_shape)),
                           broadcasts_input=broadcasts_input)
 
+def error_inputs_where(op_info, device, **kwargs):
+    shape = (S,)
+    err_msg = "Expected all tensors to be on the same device"
+    for devices in product(('cpu', device), repeat=3):
+        if len(set(devices)) == 2:
+            si = SampleInput(make_tensor(shape, device=devices[0], dtype=torch.float32),
+                             args=(make_tensor(shape, dtype=torch.bool, device=devices[1]),
+                             make_tensor(shape, device=devices[2], dtype=torch.float32)))
+            yield ErrorInput(si, error_type=RuntimeError, error_regex=err_msg)
+
 def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
 
@@ -14539,9 +14549,12 @@
            op=lambda self, condition, other: torch.where(condition, self, other),
            ref=lambda self, condition, other: np.where(condition, self, other),
            sample_inputs_func=sample_inputs_where,
+           error_inputs_func=error_inputs_where,
            supports_out=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
+           decorators=(
+               DecorateInfo(onlyCUDA, "TestCommon", 'test_errors'),),
            skips=(
                # test does not work with passing lambda for op
                # AssertionError: False is not true :