Check for consistent devices in at::where (#33432)
Summary:
Changelog:
- Add a check to ensure that all inputs to `where` lie on the same device
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33432
Test Plan:
- Added test_where_invalid_device
Fixes https://github.com/pytorch/pytorch/issues/33422
Differential Revision: D19981115
Pulled By: VitalyFedyunin
fbshipit-source-id: 745896927edb53f61f3dd48ba9e1e6cd10d35434
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index 595f8ab..80a69d0 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -101,10 +101,14 @@
} else if (localScalar.isBoolean()) {
return localScalar.to<bool>();
}
- AT_ERROR("expected non-Tensor backed scalar");
+ AT_ERROR("expected non-Tensor backend scalar");
}
Tensor where(const Tensor& condition, const Tensor& self, const Tensor& other) {
+ TORCH_CHECK(condition.device() == self.device() && self.device() == other.device(),
+ "expected condition, x and y to be on the same device, but condition is on ",
+ condition.device(), " and x and y are on ", self.device(), " and ", other.device(),
+ " respectively");
if (condition.scalar_type() != ScalarType::Byte && condition.scalar_type() != ScalarType::Bool) {
AT_ERROR("Expected condition to have ScalarType Byte, but got ScalarType ",
toString(condition.scalar_type()));
diff --git a/test/test_torch.py b/test/test_torch.py
index 95f7bcb..6088077 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -711,6 +711,17 @@
test((10,))
test((5, 5))
+ def test_where_invalid_device(self):
+ if torch.cuda.is_available():
+ for devices in [('cpu', 'cuda', 'cuda'), ('cuda', 'cpu', 'cpu'),
+ ('cuda', 'cpu', 'cuda'), ('cpu', 'cuda', 'cpu')]:
+ condition = torch.rand(16, device=devices[0])
+ x = torch.rand(16, device=devices[1])
+ y = torch.rand(16, device=devices[2])
+ with self.assertRaisesRegex(RuntimeError,
+ "expected condition, x and y to be on the same device"):
+ torch.where(condition, x, y)
+
def test_where_bool_tensor(self):
for d in torch.testing.get_all_device_types():
a = torch.tensor([True, False], device=d)