Generalize IValue's aliased hash handling for opaque tensors (#70371)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70371
This PR generalizes the aliased hash handling for opaque tensors beyond MKL-DNN.
ghstack-source-id: 147328304
Test Plan: Run existing tests.
Reviewed By: zou3519
Differential Revision: D33301787
fbshipit-source-id: db741ac347e933f8d65b029cd5be5f01804a960e
(cherry picked from commit aa8822a31a1002ea0c2440041e5e8cb862666535)
diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index f6c9901..1b15586 100644
--- a/aten/src/ATen/core/ivalue.h
+++ b/aten/src/ATen/core/ivalue.h
@@ -286,12 +286,6 @@
private:
static bool isAliasOf(const at::Tensor& a, const at::Tensor& b) {
- // mkldnn tensors dont have views or storage, so we compare
- // based on tensor impl. //TODO: find a way to use mkldnn storage
- if (a.is_mkldnn() || b.is_mkldnn()) {
- return a.unsafeGetTensorImpl() == b.unsafeGetTensorImpl();
- }
-
if (a.is_sparse()) {
return isAliasOf(a._values(), b) || isAliasOf(a._indices(), b);
}
@@ -309,6 +303,13 @@
isAliasOf(a, b.col_indices());
}
+ // Opaque tensors such as the ones constructed by the MKL-DNN backend
+ // don't have storage so we just compare their TensorImpls.
+ // TODO: Find way to expose alias info for opaque tensors.
+ if (!a.has_storage() || !b.has_storage()) {
+ return a.unsafeGetTensorImpl() == b.unsafeGetTensorImpl();
+ }
+
return a.is_alias_of(b);
}
@@ -900,12 +901,7 @@
// Detect aliased tensors.
struct HashAliasedIValue {
size_t hashTensor(const at::Tensor& ten) const {
- if (ten.is_mkldnn()) {
- // MKLDNN tensors dont have storage and dont create views
- // or aliasing so we can just use Tensor pointer, TODO: find way
- // to use mkldnn storage
- return reinterpret_cast<size_t>(ten.unsafeGetTensorImpl());
- } else if (ten.is_sparse()) {
+ if (ten.is_sparse()) {
// COO sparse tensors have a "values" tensor and an "indices" tensor
// so this will detect overlap of sparse tensors that share a values
// tensor, but not sparse tensors that share an indices tensor.
@@ -915,6 +911,11 @@
// so this will detect overlap of sparse tensors that share a values
// tensor, but not sparse tensors that share an indices tensor.
return hashTensor(ten.values());
+ } else if (!ten.has_storage()) {
+ // Opaque tensors such as the ones constructed by the MKL-DNN backend
+ // don't have storage so we just use their TensorImpls.
+ // TODO: Find way to expose alias info for opaque tensors.
+ return reinterpret_cast<size_t>(ten.unsafeGetTensorImpl());
} else {
return reinterpret_cast<size_t>(
ten.storage().unsafeGetStorageImpl());