Add `torch._check*` functions analogous to C++ `TORCH_CHECK*` (#88725)

Adds `_check`, `_check_index`, `_check_value`, `_check_type`, `_check_not_implemented`, `_check_tensor_all`

Part of #72948
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88725
Approved by: https://github.com/albanD
diff --git a/test/test_torch.py b/test/test_torch.py
index 71825b1..b92ce39 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -734,8 +734,73 @@
                 self.assertEqual((), torch.nn.functional.multi_margin_loss(input, target, reduction='mean').shape)
                 self.assertEqual((), torch.nn.functional.multi_margin_loss(input, target, reduction='sum').shape)
 
+    # Test that `torch._check_tensor_all` raises errors in the correct cases
+    def test_check_tensor_all(self, device):
+        default_message = 'Expected cond to be True'
+        check_fn = torch._check_tensor_all
+        expected_error = RuntimeError
+
+        # cond must be a tensor
+        with self.assertRaisesRegex(TypeError, 'cond must be a tensor'):
+            check_fn(True)
+
+        # cond tensor must be boolean
+        with self.assertRaisesRegex(TypeError, 'cond tensor must have dtype torch.bool'):
+            check_fn(torch.ones(1, device=device))
+
+        test_sizes = [
+            (),
+            (1,),
+            (10,),
+            (1, 1),
+            (1, 10),
+            (10, 1),
+            (10, 10),
+            (1, 1, 1),
+            (10, 1, 1),
+            (1, 10, 1),
+            (10, 10, 10),
+        ]
+        for size in test_sizes:
+            t_all_true = torch.ones(size, dtype=torch.bool, device=device)
+            t_all_false = torch.zeros(size, dtype=torch.bool, device=device)
+
+            # Should not raise error
+            check_fn(t_all_true)
+
+            with self.assertRaisesRegex(expected_error, default_message):
+                check_fn(t_all_false)
+
+            if t_all_true.numel() > 1:
+                t_all_true_but_one = t_all_true.clone()
+                # Choose a random element to set to false
+                idx = (random.choice(range(dim_size)) for dim_size in size)
+                t_all_true_but_one[(..., *idx)] = False
+
+                with self.assertRaisesRegex(expected_error, default_message):
+                    check_fn(t_all_true_but_one)
+
+            # Test a simple failure message
+            message = 'message'
+            with self.assertRaisesRegex(expected_error, message):
+                check_fn(t_all_false, lambda: message)
+
+            # Test message with tensor
+            def message():
+                return torch.arange(4)
+
+            with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
+                check_fn(t_all_false, message)
+
+            # Test format string message
+            def message():
+                return f"{'test'} {[1, 2, 'a', True]} {True} {100} {torch.arange(4)}"
+
+            with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
+                check_fn(t_all_false, message)
+
     # Test that `TORCH_CHECK_TENSOR_ALL` raises errors that propagate from C++ to Python
-    def test_check_tensor(self, device):
+    def test_check_tensor_internal(self, device):
         test_sizes = [
             (),
             (1,),
@@ -5683,6 +5748,52 @@
         with self.assertRaises(IndexError):
             reference[0.0, :, 0.0] = 1
 
+    # Test `torch._check*` functions
+    def test_check(self):
+        test_cases = [
+            # check function, expected error
+            (torch._check, RuntimeError),
+            (torch._check_index, IndexError),
+            (torch._check_value, ValueError),
+            (torch._check_type, TypeError),
+            (torch._check_not_implemented, NotImplementedError),
+        ]
+
+        for check_fn, expected_error in test_cases:
+            # cond=True should not raise an error
+            check_fn(True)
+
+            # Test default failure message for cond=False
+            default_message = 'Expected cond to be True'
+            with self.assertRaisesRegex(expected_error, default_message):
+                check_fn(False)
+
+            # Test a simple failure message
+            message = 'message'
+            with self.assertRaisesRegex(expected_error, message):
+                check_fn(False, lambda: message)
+
+            # Test message with tensor
+            def message():
+                return torch.arange(4)
+
+            with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
+                check_fn(False, message)
+
+            # Test format string message
+            def message():
+                return f"{'test'} {[1, 2, 'a', True]} {True} {100} {torch.arange(4)}"
+
+            with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
+                check_fn(False, message)
+
+            # Test incorrect `cond` arg type
+            with self.assertRaisesRegex(TypeError, 'cond must be a bool'):
+                check_fn('wrong type')
+
+            with self.assertRaisesRegex(TypeError, 'cond must be a bool'):
+                check_fn(torch.tensor(True))
+
     # FIXME: move to indexing test suite
     def test_index_add(self):
         for device in get_all_device_types():
diff --git a/torch/__init__.py b/torch/__init__.py
index 798f6dc..712fbf9 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -920,6 +920,151 @@
     return _C._get_warnAlways()
 
 ################################################################################
+# Define error checking functions
+################################################################################
+
+# These error checking functions must be kept consistent with their C++
+# equivalents. Their C++ equivalents are mentioned where applicable.
+
+def _check_with(error_type, cond, message):
+    if not isinstance(cond, builtins.bool):
+        raise TypeError(f'cond must be a bool, but got {type(cond)}')
+
+    if cond:
+        return
+
+    # error_type must be a subclass of Exception and not subclass of Warning
+    assert issubclass(error_type, Exception) and not issubclass(error_type, Warning)
+
+    if message is None:
+        message_evaluated = (
+            'Expected cond to be True, but got False. (Could this error '
+            'message be improved? If so, please report an enhancement request '
+            'to PyTorch.)')
+
+    else:
+        if not callable(message):
+            raise TypeError('message must be a callable')
+
+        message_evaluated = str(message())
+
+    raise error_type(message_evaluated)
+
+def _check(cond, message=None):
+    r"""Throws error containing an optional message if the specified condition
+    is False.
+
+    Error type: ``RuntimeError``
+
+    C++ equivalent: ``TORCH_CHECK``
+
+    Args:
+        cond (:class:`bool`): If False, throw error
+
+        message (Callable, optional): Callable that returns either a string or
+            an object that has a ``__str__()`` method to be used as the error
+            message. Default: ``None``
+    """
+    _check_with(RuntimeError, cond, message)
+
+def _check_index(cond, message=None):
+    r"""Throws error containing an optional message if the specified condition
+    is False.
+
+    Error type: ``IndexError``
+
+    C++ equivalent: ``TORCH_CHECK_INDEX``
+
+    Args:
+        cond (:class:`bool`): If False, throw error
+
+        message (Callable, optional): Callable that returns either a string or
+            an object that has a ``__str__()`` method to be used as the error
+            message. Default: ``None``
+    """
+    _check_with(IndexError, cond, message)
+
+def _check_value(cond, message=None):
+    r"""Throws error containing an optional message if the specified condition
+    is False.
+
+    Error type: ``ValueError``
+
+    C++ equivalent: ``TORCH_CHECK_VALUE``
+
+    Args:
+        cond (:class:`bool`): If False, throw error
+
+        message (Callable, optional): Callable that returns either a string or
+            an object that has a ``__str__()`` method to be used as the error
+            message. Default: ``None``
+    """
+    _check_with(ValueError, cond, message)
+
+def _check_type(cond, message=None):
+    r"""Throws error containing an optional message if the specified condition
+    is False.
+
+    Error type: ``TypeError``
+
+    C++ equivalent: ``TORCH_CHECK_TYPE``
+
+    Args:
+        cond (:class:`bool`): If False, throw error
+
+        message (Callable, optional): Callable that returns either a string or
+            an object that has a ``__str__()`` method to be used as the error
+            message. Default: ``None``
+    """
+    _check_with(TypeError, cond, message)
+
+def _check_not_implemented(cond, message=None):
+    r"""Throws error containing an optional message if the specified condition
+    is False.
+
+    Error type: ``NotImplementedError``
+
+    C++ equivalent: ``TORCH_CHECK_NOT_IMPLEMENTED``
+
+    Args:
+        cond (:class:`bool`): If False, throw error
+
+        message (Callable, optional): Callable that returns either a string or
+            an object that has a ``__str__()`` method to be used as the error
+            message. Default: ``None``
+    """
+    _check_with(NotImplementedError, cond, message)
+
+def _check_tensor_all_with(error_type, cond, message=None):
+    if not torch.is_tensor(cond):
+        raise TypeError(f'cond must be a tensor, but got {type(cond)}')
+
+    if not cond.dtype == torch.bool:
+        raise TypeError(
+            f'cond tensor must have dtype torch.bool, but got {cond.dtype}')
+
+    _check_with(error_type, cond._is_all_true().item(), message)
+
+# C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
+def _check_tensor_all(cond, message=None):
+    r"""Throws error containing an optional message if the specified condition
+    is False.
+
+    Error type: ``RuntimeError``
+
+    C++ equivalent: ``TORCH_CHECK_TENSOR_ALL``
+
+    Args:
+        cond (:class:`torch.Tensor`): Tensor of dtype ``torch.bool``. If any
+            element is ``False``, throw error
+
+        message (Callable, optional): Callable that returns either a string or
+            an object that has a ``__str__()`` method to be used as the error
+            message. Default: ``None``
+    """
+    _check_tensor_all_with(RuntimeError, cond, message)
+
+################################################################################
 # Define numeric constants
 ################################################################################