Enable `torch.isclose` to suppport bool tensors (#61271)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/60533
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61271
Reviewed By: zhxchen17
Differential Revision: D29737618
Pulled By: SplitInfinity
fbshipit-source-id: 45314bc7e0b9a28c10700455b1e6267c0db3eefc
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index bccd2f7..0217028 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -133,14 +133,16 @@
// by the default scalar type then this may cause an incorrect result.
// Computes allowed and actual error
- Tensor cast_other;
+ Tensor cast_self, cast_other;
+ cast_self = self.scalar_type() == at::kBool ? self.to(at::get_default_dtype()) : self;
if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
cast_other = other.to(at::get_default_dtype());
} else {
cast_other = other;
}
+
Tensor allowed_error = atol + (rtol * cast_other).abs();
- Tensor actual_error = (self - cast_other).abs();
+ Tensor actual_error = (cast_self - cast_other).abs();
// Computes finite closeness
close.__ior__(at::isfinite(actual_error).__iand__(actual_error <= allowed_error));
diff --git a/test/test_testing.py b/test/test_testing.py
index 3f7709d..9b4c73a 100644
--- a/test/test_testing.py
+++ b/test/test_testing.py
@@ -244,8 +244,6 @@
expected = test[2]
self.assertEqual(actual.item(), expected)
- # torch.close is not implemented for bool tensors
- # see https://github.com/pytorch/pytorch/issues/33048
def test_isclose_comparetensors_bool(self, device):
tests = (
(True, True, True),
@@ -254,9 +252,7 @@
(False, True, False),
)
- with self.assertRaises(RuntimeError):
- self._isclose_helper(tests, device, torch.bool, False)
-
+ self._isclose_helper(tests, device, torch.bool, False)
self._comparetensors_helper(tests, device, torch.bool, False)
@dtypes(torch.uint8,