add function name to error messages generated by checked_tensor_unwrap (#24187)
Summary:
Improve error messages by showing the relevant function call that failed.
Before:
```
>>> torch.ones(1, dtype=torch.float) < torch.ones(1, dtype=torch.double)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected object of scalar type Float but got scalar type Double for argument https://github.com/pytorch/pytorch/issues/2 'other'
```
After:
```
>>> torch.ones(1, dtype=torch.float) < torch.ones(1, dtype=torch.double)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Expected object of scalar type Float but got scalar type Double for argument https://github.com/pytorch/pytorch/issues/2 'other' in call to _th_lt
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24187
Differential Revision: D16769167
Pulled By: nairbv
fbshipit-source-id: 4992eb4e86bdac2ab8805cc5356f7f92c63e1255
diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h
index be47638..fdd8fd2 100644
--- a/aten/src/ATen/Utils.h
+++ b/aten/src/ATen/Utils.h
@@ -67,20 +67,20 @@
// TODO: Change Backend into TensorTypeId
// TODO: Stop unwrapping (this is blocked on getting rid of TH ;)
-static inline TensorImpl* checked_tensor_unwrap(const Tensor& expr, const char * name, int pos, bool allowNull, Backend backend, ScalarType scalar_type) {
+static inline TensorImpl* checked_tensor_unwrap(const Tensor& expr, const char * name, int pos, const char * api, bool allowNull, Backend backend, ScalarType scalar_type) {
if(allowNull && !expr.defined()) {
return nullptr;
}
if (tensorTypeIdToBackend(expr.type_id()) != backend) {
AT_ERROR("Expected object of backend ", backend, " but got backend ", tensorTypeIdToBackend(expr.type_id()),
- " for argument #", pos, " '", name, "'");
+ " for argument #", pos, " '", name, "' in call to ", api);
}
if (expr.scalar_type() != scalar_type) {
AT_ERROR("Expected object of scalar type ", scalar_type, " but got scalar type ", expr.scalar_type(),
- " for argument #", pos, " '", name, "'");
+ " for argument #", pos, " '", name, "' in call to ", api);
}
if (expr.is_variable()) { // TODO: change this to check `.requires_grad()` and `GradMode::is_enabled()` when Variable and Tensor are merged
- AT_ERROR("Expected Tensor (not Variable) for argument #", pos, " '", name, "'");
+ AT_ERROR("Expected Tensor (not Variable) for argument #", pos, " '", name, "' in call to ", api);
}
return expr.unsafeGetTensorImpl();
}
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index 2006d17..af4639d 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -269,42 +269,42 @@
'THTensor*':
CodeTemplate(
'checked_tensor_unwrap('
- '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
+ '${arg_name}, "${arg_name}", ${arg_pos}, "${api_name}", ${null_okay}, '
'Backend::${Backend}, ScalarType::${ScalarName})'),
'THByteTensor*':
CodeTemplate(
'checked_tensor_unwrap('
- '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
+ '${arg_name}, "${arg_name}", ${arg_pos}, "${api_name}", ${null_okay}, '
'Backend::${Backend}, ScalarType::Byte)'),
'THBoolTensor*':
CodeTemplate(
'checked_tensor_unwrap('
- '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
+ '${arg_name}, "${arg_name}", ${arg_pos}, "${api_name}", ${null_okay}, '
'Backend::${Backend}, ScalarType::Bool)'),
'THIndexTensor*':
CodeTemplate(
'checked_tensor_unwrap('
- '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
+ '${arg_name}, "${arg_name}", ${arg_pos}, "${api_name}", ${null_okay}, '
'Backend::${Backend}, ScalarType::Long)'),
'THIntegerTensor*':
CodeTemplate(
'checked_tensor_unwrap('
- '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
+ '${arg_name}, "${arg_name}", ${arg_pos}, "${api_name}", ${null_okay}, '
'Backend::${Backend}, ScalarType::Int)'),
'THDenseTensor*':
CodeTemplate(
'checked_tensor_unwrap('
- '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
+ '${arg_name}, "${arg_name}", ${arg_pos}, "${api_name}", ${null_okay}, '
'Backend::${DenseBackend}, ScalarType::${ScalarName})'),
'THDenseIndexTensor*':
CodeTemplate(
'checked_tensor_unwrap('
- '${arg_name},"${arg_name}",${arg_pos}, ${null_okay}, '
+ '${arg_name}, "${arg_name}", ${arg_pos}, "${api_name}", ${null_okay}, '
'Backend::${DenseBackend}, ScalarType::Long)'),
'THStorage*':
CodeTemplate(
'checked_storage('
- '${arg_name},"${arg_name}",${arg_pos}, '
+ '${arg_name}, "${arg_name}", ${arg_pos}, '
# We're punning here (Backend and DeviceType constructors coincide)
# but DeviceType is the correct way to classify storages
'DeviceType::${Backend}, at::scalarTypeToTypeMeta(ScalarType::${ScalarName}))'),
@@ -1479,8 +1479,8 @@
check_cast = CHECKED_CAST[arg['type']].substitute(
case_env, arg_name=arg['name'], arg_pos=count,
- null_okay=null_okay, default_init=default_init,
- size=arg.get('size'))
+ api_name=option['api_name'], null_okay=null_okay,
+ default_init=default_init, size=arg.get('size'))
case_body.append("auto {}_ = {};".format(
arg['name'], check_cast))
if drop_argument(arg, option):
diff --git a/test/test_torch.py b/test/test_torch.py
index 000faed..5d0326e 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -12699,6 +12699,11 @@
e1.fill_diagonal_(v, wrap=True)
self.assertEqual(e1, e2)
+ def test_function_unwrap_message(self):
+ self.assertRaisesRegex(RuntimeError, ' call to _th_lt',
+ lambda: torch.ones(1, dtype=torch.float) < torch.ones(1, dtype=torch.double))
+
+
# Functions to test negative dimension wrapping
METHOD = 1
INPLACE_METHOD = 2