Revert D26837780: Revert D26819810: Revert D26815021: Revert D26744062: Add assert_async

Test Plan: revert-hammer

Differential Revision:
D26837780

Original commit changeset: 21567cab5c0f

fbshipit-source-id: 8ea735e5fdc97e32ae3fafd40297a1b8a7cd34b0
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index 17c8b5d..6787312 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -193,6 +193,10 @@
   TORCH_INTERNAL_ASSERT(false, "Expected non-Tensor backend scalar");
 }
 
+void assert_async_cpu(const Tensor& self) {
+  TORCH_CHECK(native::is_nonzero(self), "Expected Tensor with single nonzero value, but got zero");
+}
+
 namespace {
 
 // DO NOT USE THIS -- it's just an implementation detail of wrapped_scalar tensor below.
diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu
index b10ae52..186310c 100644
--- a/aten/src/ATen/native/cuda/TensorCompare.cu
+++ b/aten/src/ATen/native/cuda/TensorCompare.cu
@@ -59,4 +59,26 @@
 REGISTER_DISPATCH(isposinf_stub, &isposinf_kernel_impl);
 REGISTER_DISPATCH(isneginf_stub, &isneginf_kernel_impl);
 
+template <typename scalar_t>
+__global__ void assert_async_cuda_kernel(scalar_t* input) {
+  CUDA_KERNEL_ASSERT(input[0] != 0);
+}
+
+__global__ void assert_async_cuda_kernel(c10::complex<float>* input) {
+  CUDA_KERNEL_ASSERT(input[0] != c10::complex<float>(0, 0));
+}
+__global__ void assert_async_cuda_kernel(c10::complex<double>* input) {
+  CUDA_KERNEL_ASSERT(input[0] != c10::complex<double>(0, 0));
+}
+
+void assert_async_cuda(const Tensor& self) {
+  auto n = self.numel();
+  TORCH_CHECK(n != 0, "Boolean value of Tensor with no values is ambiguous");
+  TORCH_CHECK(n < 2, "Boolean value of Tensor with more than one value is ambiguous");
+  auto stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "assert_async_cuda", [&] {
+    assert_async_cuda_kernel<<<1, 1, 0, stream>>>(self.data_ptr<scalar_t>());
+  });
+}
+
 }} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index a34ea83..77d0bbc 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -118,6 +118,14 @@
 
 - func: align_tensors(Tensor[] tensors) -> Tensor[]
 
+# Not assert because it's a keyword; not Assert because FX already
+# took that syntax
+# TODO: need to specify this is side-effectful somehow
+- func: assert_async(Tensor self) -> ()
+  dispatch:
+    CPU: assert_async_cpu
+    CUDA: assert_async_cuda
+
 - func: refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a)
   variants: method
 
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 60a7947..9e23891 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1825,6 +1825,37 @@
 t2.start()
 """])
 
+    def test_cuda_assert_async(self):
+        with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"):
+            torch.assert_async(torch.tensor([], device="cuda"))
+        with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with more than one value is ambiguous"):
+            torch.assert_async(torch.tensor([0, 0], device="cuda"))
+
+        torch.assert_async(torch.tensor(1, device="cuda"))
+        torch.assert_async(torch.tensor(0.1, device="cuda"))
+        torch.assert_async(torch.tensor(-0.1, device="cuda"))
+        torch.assert_async(torch.tensor(True, device="cuda"))
+        torch.assert_async(torch.tensor(0 + 0.1j, device="cuda"))
+
+        fail_stmts = [
+            "torch.assert_async(torch.tensor(0, device='cuda'))",
+            "torch.assert_async(torch.tensor(0.0, device='cuda'))",
+            "torch.assert_async(torch.tensor(False, device='cuda'))",
+            "torch.assert_async(torch.tensor(0+ 0 j, device='cuda'))",
+        ]
+
+        import subprocess
+        for stmt in fail_stmts:
+            with self.subTest(stmt=stmt):
+                r = subprocess.call([sys.executable, '-c', f"""\
+import torch
+
+{stmt}
+torch.cuda.synchronize()
+"""])
+                self.assertTrue(r != 0)
+
+
     def test_grad_scaling_unscale(self, dtype=torch.float):
         inv_scale = torch.full((1,), 0.25, dtype=torch.float, device="cuda:0")
         found_inf = torch.full((1,), 0.0, dtype=torch.float, device="cuda:0")
diff --git a/test/test_torch.py b/test/test_torch.py
index 5c43a93..8f1072c 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2356,6 +2356,32 @@
             self.assertTrue(torch.tensor([1]).is_nonzero())
             self.assertFalse(torch.tensor([[0]]).is_nonzero())
             self.assertTrue(torch.tensor([[1]]).is_nonzero())
+            self.assertTrue(torch.tensor(0.1).is_nonzero())
+            self.assertTrue(torch.tensor(-0.1).is_nonzero())
+            self.assertFalse(torch.tensor(0.0).is_nonzero())
+            self.assertTrue(torch.tensor(True).is_nonzero())
+            self.assertFalse(torch.tensor(False).is_nonzero())
+            self.assertFalse(torch.tensor(0 + 0j).is_nonzero())
+            self.assertTrue(torch.tensor(0 + 0.1j).is_nonzero())
+
+        def test_assert_async(self):
+            with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with no values is ambiguous"):
+                torch.assert_async(torch.tensor([]))
+            with self.assertRaisesRegex(RuntimeError, "Boolean value of Tensor with more than one value is ambiguous"):
+                torch.assert_async(torch.tensor([0, 0]))
+            with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
+                torch.assert_async(torch.tensor(0))
+            torch.assert_async(torch.tensor(1))
+            torch.assert_async(torch.tensor(0.1))
+            torch.assert_async(torch.tensor(-0.1))
+            with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
+                torch.assert_async(torch.tensor(0.0))
+            torch.assert_async(torch.tensor(True))
+            with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
+                torch.assert_async(torch.tensor(False))
+            torch.assert_async(torch.tensor(0 + 0.1j))
+            with self.assertRaisesRegex(RuntimeError, "Expected Tensor with single nonzero value, but got zero"):
+                torch.assert_async(torch.tensor(0 + 0j))
 
         # NB: we must not be built with CUDA; if we are built with CUDA but no CUDA
         # is available, we get a different error.
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 2105b7b..f29769d 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -10565,6 +10565,24 @@
     device(type='cpu')
 """)
 
+add_docstr(torch.assert_async,
+           r"""
+assert_async(tensor) -> void
+
+Asynchronously assert that the contents of tensor are nonzero.  For CPU tensors,
+this is equivalent to ``assert tensor`` or ``assert tensor.is_nonzero()``; for
+CUDA tensors, we DO NOT synchronize and you may only find out the assertion
+failed at a later CUDA kernel launch.  Asynchronous assertion can be helpful for
+testing invariants in CUDA tensors without giving up performance.  This function
+is NOT intended to be used for regular error checking, as it will trash your CUDA
+context if the assert fails (forcing you to restart your PyTorch process.)
+
+Args:
+    tensor (Tensor): a one element tensor to test to see if it is nonzero.  Zero
+        elements (including False for boolean tensors) cause an assertion failure
+        to be raised.
+""")
+
 add_docstr(torch.searchsorted,
            r"""
 searchsorted(sorted_sequence, values, *, out_int32=False, right=False, out=None) -> Tensor
diff --git a/torch/overrides.py b/torch/overrides.py
index 18fe437..bb4478c 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -272,6 +272,7 @@
         torch.argmin: lambda input: -1,
         torch.argsort: lambda input, dim=None: -1,
         torch.asin: lambda input, out=None: -1,
+        torch.assert_async: lambda input: -1,
         torch.arcsin: lambda input, out=None: -1,
         torch.asinh: lambda input, out=None: -1,
         torch.arcsinh: lambda input, out=None: -1,