Enable `torch.isclose` to suppport bool tensors (#61271)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/60533

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61271

Reviewed By: zhxchen17

Differential Revision: D29737618

Pulled By: SplitInfinity

fbshipit-source-id: 45314bc7e0b9a28c10700455b1e6267c0db3eefc
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index bccd2f7..0217028 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -133,14 +133,16 @@
   // by the default scalar type then this may cause an incorrect result.
 
   // Computes allowed and actual error
-  Tensor cast_other;
+  Tensor cast_self, cast_other;
+  cast_self = self.scalar_type() == at::kBool ? self.to(at::get_default_dtype()) : self;
   if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
     cast_other = other.to(at::get_default_dtype());
   } else {
     cast_other = other;
   }
+
   Tensor allowed_error = atol + (rtol * cast_other).abs();
-  Tensor actual_error = (self - cast_other).abs();
+  Tensor actual_error = (cast_self - cast_other).abs();
 
   // Computes finite closeness
   close.__ior__(at::isfinite(actual_error).__iand__(actual_error <= allowed_error));
diff --git a/test/test_testing.py b/test/test_testing.py
index 3f7709d..9b4c73a 100644
--- a/test/test_testing.py
+++ b/test/test_testing.py
@@ -244,8 +244,6 @@
             expected = test[2]
             self.assertEqual(actual.item(), expected)
 
-    # torch.close is not implemented for bool tensors
-    # see https://github.com/pytorch/pytorch/issues/33048
     def test_isclose_comparetensors_bool(self, device):
         tests = (
             (True, True, True),
@@ -254,9 +252,7 @@
             (False, True, False),
         )
 
-        with self.assertRaises(RuntimeError):
-            self._isclose_helper(tests, device, torch.bool, False)
-
+        self._isclose_helper(tests, device, torch.bool, False)
         self._comparetensors_helper(tests, device, torch.bool, False)
 
     @dtypes(torch.uint8,