Check for consistent devices in at::where (#33432)

Summary:
Changelog:
- Add a check to ensure that all inputs to `where` lie on the same device
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33432

Test Plan:
- Added test_where_invalid_device

Fixes https://github.com/pytorch/pytorch/issues/33422

Differential Revision: D19981115

Pulled By: VitalyFedyunin

fbshipit-source-id: 745896927edb53f61f3dd48ba9e1e6cd10d35434
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index 595f8ab..80a69d0 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -101,10 +101,14 @@
   } else if (localScalar.isBoolean()) {
     return localScalar.to<bool>();
   }
-  AT_ERROR("expected non-Tensor backed scalar");
+  AT_ERROR("expected non-Tensor backend scalar");
 }
 
 Tensor where(const Tensor& condition, const Tensor& self, const Tensor& other) {
+  TORCH_CHECK(condition.device() == self.device() && self.device() == other.device(),
+              "expected condition, x and y to be on the same device, but condition is on ",
+              condition.device(), " and x and y are on ", self.device(), " and ", other.device(),
+              " respectively");
   if (condition.scalar_type() != ScalarType::Byte && condition.scalar_type() != ScalarType::Bool) {
     AT_ERROR("Expected condition to have ScalarType Byte, but got ScalarType ",
                   toString(condition.scalar_type()));
diff --git a/test/test_torch.py b/test/test_torch.py
index 95f7bcb..6088077 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -711,6 +711,17 @@
         test((10,))
         test((5, 5))
 
+    def test_where_invalid_device(self):
+        if torch.cuda.is_available():
+            for devices in [('cpu', 'cuda', 'cuda'), ('cuda', 'cpu', 'cpu'),
+                            ('cuda', 'cpu', 'cuda'), ('cpu', 'cuda', 'cpu')]:
+                condition = torch.rand(16, device=devices[0])
+                x = torch.rand(16, device=devices[1])
+                y = torch.rand(16, device=devices[2])
+                with self.assertRaisesRegex(RuntimeError,
+                                            "expected condition, x and y to be on the same device"):
+                    torch.where(condition, x, y)
+
     def test_where_bool_tensor(self):
         for d in torch.testing.get_all_device_types():
             a = torch.tensor([True, False], device=d)