Add `torch._check*` functions analogous to C++ `TORCH_CHECK*` (#88725)
Adds `_check`, `_check_index`, `_check_value`, `_check_type`, `_check_not_implemented`, `_check_tensor_all`
Part of #72948
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88725
Approved by: https://github.com/albanD
diff --git a/test/test_torch.py b/test/test_torch.py
index 71825b1..b92ce39 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -734,8 +734,73 @@
self.assertEqual((), torch.nn.functional.multi_margin_loss(input, target, reduction='mean').shape)
self.assertEqual((), torch.nn.functional.multi_margin_loss(input, target, reduction='sum').shape)
+ # Test that `torch._check_tensor_all` raises errors in the correct cases
+ def test_check_tensor_all(self, device):
+ default_message = 'Expected cond to be True'
+ check_fn = torch._check_tensor_all
+ expected_error = RuntimeError
+
+ # cond must be a tensor
+ with self.assertRaisesRegex(TypeError, 'cond must be a tensor'):
+ check_fn(True)
+
+ # cond tensor must be boolean
+ with self.assertRaisesRegex(TypeError, 'cond tensor must have dtype torch.bool'):
+ check_fn(torch.ones(1, device=device))
+
+ test_sizes = [
+ (),
+ (1,),
+ (10,),
+ (1, 1),
+ (1, 10),
+ (10, 1),
+ (10, 10),
+ (1, 1, 1),
+ (10, 1, 1),
+ (1, 10, 1),
+ (10, 10, 10),
+ ]
+ for size in test_sizes:
+ t_all_true = torch.ones(size, dtype=torch.bool, device=device)
+ t_all_false = torch.zeros(size, dtype=torch.bool, device=device)
+
+ # Should not raise error
+ check_fn(t_all_true)
+
+ with self.assertRaisesRegex(expected_error, default_message):
+ check_fn(t_all_false)
+
+ if t_all_true.numel() > 1:
+ t_all_true_but_one = t_all_true.clone()
+ # Choose a random element to set to false
+ idx = (random.choice(range(dim_size)) for dim_size in size)
+ t_all_true_but_one[(..., *idx)] = False
+
+ with self.assertRaisesRegex(expected_error, default_message):
+ check_fn(t_all_true_but_one)
+
+ # Test a simple failure message
+ message = 'message'
+ with self.assertRaisesRegex(expected_error, message):
+ check_fn(t_all_false, lambda: message)
+
+ # Test message with tensor
+ def message():
+ return torch.arange(4)
+
+ with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
+ check_fn(t_all_false, message)
+
+ # Test format string message
+ def message():
+ return f"{'test'} {[1, 2, 'a', True]} {True} {100} {torch.arange(4)}"
+
+ with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
+ check_fn(t_all_false, message)
+
# Test that `TORCH_CHECK_TENSOR_ALL` raises errors that propagate from C++ to Python
- def test_check_tensor(self, device):
+ def test_check_tensor_internal(self, device):
test_sizes = [
(),
(1,),
@@ -5683,6 +5748,52 @@
with self.assertRaises(IndexError):
reference[0.0, :, 0.0] = 1
+ # Test `torch._check*` functions
+ def test_check(self):
+ test_cases = [
+ # check function, expected error
+ (torch._check, RuntimeError),
+ (torch._check_index, IndexError),
+ (torch._check_value, ValueError),
+ (torch._check_type, TypeError),
+ (torch._check_not_implemented, NotImplementedError),
+ ]
+
+ for check_fn, expected_error in test_cases:
+ # cond=True should not raise an error
+ check_fn(True)
+
+ # Test default failure message for cond=False
+ default_message = 'Expected cond to be True'
+ with self.assertRaisesRegex(expected_error, default_message):
+ check_fn(False)
+
+ # Test a simple failure message
+ message = 'message'
+ with self.assertRaisesRegex(expected_error, message):
+ check_fn(False, lambda: message)
+
+ # Test message with tensor
+ def message():
+ return torch.arange(4)
+
+ with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
+ check_fn(False, message)
+
+ # Test format string message
+ def message():
+ return f"{'test'} {[1, 2, 'a', True]} {True} {100} {torch.arange(4)}"
+
+ with self.assertRaisesRegex(expected_error, re.escape(str(message()))):
+ check_fn(False, message)
+
+ # Test incorrect `cond` arg type
+ with self.assertRaisesRegex(TypeError, 'cond must be a bool'):
+ check_fn('wrong type')
+
+ with self.assertRaisesRegex(TypeError, 'cond must be a bool'):
+ check_fn(torch.tensor(True))
+
# FIXME: move to indexing test suite
def test_index_add(self):
for device in get_all_device_types():
diff --git a/torch/__init__.py b/torch/__init__.py
index 798f6dc..712fbf9 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -920,6 +920,151 @@
return _C._get_warnAlways()
################################################################################
+# Define error checking functions
+################################################################################
+
+# These error checking functions must be kept consistent with their C++
+# equivalents. Their C++ equivalents are mentioned where applicable.
+
+def _check_with(error_type, cond, message):
+ if not isinstance(cond, builtins.bool):
+ raise TypeError(f'cond must be a bool, but got {type(cond)}')
+
+ if cond:
+ return
+
+ # error_type must be a subclass of Exception and not subclass of Warning
+ assert issubclass(error_type, Exception) and not issubclass(error_type, Warning)
+
+ if message is None:
+ message_evaluated = (
+ 'Expected cond to be True, but got False. (Could this error '
+ 'message be improved? If so, please report an enhancement request '
+ 'to PyTorch.)')
+
+ else:
+ if not callable(message):
+ raise TypeError('message must be a callable')
+
+ message_evaluated = str(message())
+
+ raise error_type(message_evaluated)
+
+def _check(cond, message=None):
+ r"""Throws error containing an optional message if the specified condition
+ is False.
+
+ Error type: ``RuntimeError``
+
+ C++ equivalent: ``TORCH_CHECK``
+
+ Args:
+ cond (:class:`bool`): If False, throw error
+
+ message (Callable, optional): Callable that returns either a string or
+ an object that has a ``__str__()`` method to be used as the error
+ message. Default: ``None``
+ """
+ _check_with(RuntimeError, cond, message)
+
+def _check_index(cond, message=None):
+ r"""Throws error containing an optional message if the specified condition
+ is False.
+
+ Error type: ``IndexError``
+
+ C++ equivalent: ``TORCH_CHECK_INDEX``
+
+ Args:
+ cond (:class:`bool`): If False, throw error
+
+ message (Callable, optional): Callable that returns either a string or
+ an object that has a ``__str__()`` method to be used as the error
+ message. Default: ``None``
+ """
+ _check_with(IndexError, cond, message)
+
+def _check_value(cond, message=None):
+ r"""Throws error containing an optional message if the specified condition
+ is False.
+
+ Error type: ``ValueError``
+
+ C++ equivalent: ``TORCH_CHECK_VALUE``
+
+ Args:
+ cond (:class:`bool`): If False, throw error
+
+ message (Callable, optional): Callable that returns either a string or
+ an object that has a ``__str__()`` method to be used as the error
+ message. Default: ``None``
+ """
+ _check_with(ValueError, cond, message)
+
+def _check_type(cond, message=None):
+ r"""Throws error containing an optional message if the specified condition
+ is False.
+
+ Error type: ``TypeError``
+
+ C++ equivalent: ``TORCH_CHECK_TYPE``
+
+ Args:
+ cond (:class:`bool`): If False, throw error
+
+ message (Callable, optional): Callable that returns either a string or
+ an object that has a ``__str__()`` method to be used as the error
+ message. Default: ``None``
+ """
+ _check_with(TypeError, cond, message)
+
+def _check_not_implemented(cond, message=None):
+ r"""Throws error containing an optional message if the specified condition
+ is False.
+
+ Error type: ``NotImplementedError``
+
+ C++ equivalent: ``TORCH_CHECK_NOT_IMPLEMENTED``
+
+ Args:
+ cond (:class:`bool`): If False, throw error
+
+ message (Callable, optional): Callable that returns either a string or
+ an object that has a ``__str__()`` method to be used as the error
+ message. Default: ``None``
+ """
+ _check_with(NotImplementedError, cond, message)
+
+def _check_tensor_all_with(error_type, cond, message=None):
+ if not torch.is_tensor(cond):
+ raise TypeError(f'cond must be a tensor, but got {type(cond)}')
+
+ if not cond.dtype == torch.bool:
+ raise TypeError(
+ f'cond tensor must have dtype torch.bool, but got {cond.dtype}')
+
+ _check_with(error_type, cond._is_all_true().item(), message)
+
+# C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
+def _check_tensor_all(cond, message=None):
+ r"""Throws error containing an optional message if the specified condition
+ is False.
+
+ Error type: ``RuntimeError``
+
+ C++ equivalent: ``TORCH_CHECK_TENSOR_ALL``
+
+ Args:
+ cond (:class:`torch.Tensor`): Tensor of dtype ``torch.bool``. If any
+ element is ``False``, throw error
+
+ message (Callable, optional): Callable that returns either a string or
+ an object that has a ``__str__()`` method to be used as the error
+ message. Default: ``None``
+ """
+ _check_tensor_all_with(RuntimeError, cond, message)
+
+################################################################################
# Define numeric constants
################################################################################