[quant][core][improvement] Added support for data_ptr<T> for quantized tensors to return pointer to underlying int type (e.g., int8* instead of qint*)
Summary:
Previously, data_ptr<T> did not provide support for returning a pointer to the underlying int tensor.
Instead, it returned a pointer to a quantized tensor (e.g., qint*), and backend users had to call, e.g.,
reinterpret_cast<int8*>(quantized_tensor.data_ptr()) to cast it to an int* pointer.
This PR enables direct support for returning the underlying pointer without need for casting.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75643
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/templates/TensorMethods.cpp b/aten/src/ATen/templates/TensorMethods.cpp
index be9d944..dd8f3c3 100644
--- a/aten/src/ATen/templates/TensorMethods.cpp
+++ b/aten/src/ATen/templates/TensorMethods.cpp
@@ -7,7 +7,9 @@
template <> \
TORCH_API T* TensorBase::data_ptr() const { \
TORCH_CHECK( \
- scalar_type() == ScalarType::name, \
+ scalar_type() == ScalarType::name \
+ || (isQIntType(scalar_type()) \
+ && toUnderlying(scalar_type()) == ScalarType::name), \
"expected scalar type " \
#name \
" but found ", \