Preserve coalesce state in sparse COO tensor serialization (#102647)
Fixes #101186
Also resolves the "serialization to preserve coalesced-ness" part of https://github.com/pytorch/pytorch/issues/73479
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102647
Approved by: https://github.com/mikaylagawarecki
diff --git a/test/test_serialization.py b/test/test_serialization.py
index 754d824..2847732 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -316,7 +316,7 @@
torch.save({"tensor": x}, f)
f.seek(0)
y = torch.load(f, weights_only=weights_only)
- self.assertEqual(x, y["tensor"])
+ self.assertEqual(x, y["tensor"], exact_is_coalesced=True)
_test_serialization(lambda x: x.to_sparse())
_test_serialization(lambda x: x.to_sparse_csr())
_test_serialization(lambda x: x.to_sparse_csc())
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 9e14d5e..fd4e497 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -329,7 +329,7 @@
if self.layout == torch.sparse_coo:
args_sparse = (
self.layout,
- (self._indices(), self._values(), self.size()),
+ (self._indices(), self._values(), self.size(), self.is_coalesced()),
)
else:
raise NotImplementedError(
diff --git a/torch/_utils.py b/torch/_utils.py
index 582ce60..a525144 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -237,8 +237,15 @@
data (tuple): The tensor's sparse storage representation.
"""
if layout == torch.sparse_coo:
- indices, values, size = data
+ if len(data) == 3:
+ # For BC:
+ indices, values, size = data
+ is_coalesced = None
+ else:
+ indices, values, size, is_coalesced = data
result = torch.sparse_coo_tensor(indices, values, size, check_invariants=False)
+ if is_coalesced is not None:
+ result._coalesced_(is_coalesced)
_sparse_tensors_to_validate.append(result)
return result