# blob: 77e05ad609005aca6a4d2df49178905d8a62e4af [file] [log] [blame]
# Owner(s): ["oncall: pt2"]
import tempfile
import unittest
import torch
from torch._prims.debug_prims import load_tensor_reader
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.multiprocessing.reductions import StorageWeakRef
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
skipIfRocm,
TestCase,
)
from torch.utils._content_store import (
ContentStoreReader,
ContentStoreWriter,
hash_storage,
)
@unittest.skipIf(IS_WINDOWS, "Test case not supported on Windows")
class TestContentStore(TestCase):
def test_basic(self, device):
# setup test data
x = torch.randn(4, device=device)
y = torch.randn(6, device=device)
z = x.view(2, 2)
# start writing
with tempfile.TemporaryDirectory() as loc:
writer = ContentStoreWriter(loc)
writer.write_tensor("x", x)
writer.write_tensor("y", y)
writer.write_tensor("z", z)
# do some mutation that is VC UNTRACKED
x.data.add_(1)
writer.write_tensor("x2", x)
writer.write_tensor("y2", y)
writer.write_tensor("z2", z)
del writer
reader = ContentStoreReader(loc)
n_x = reader.read_tensor("x")
n_y = reader.read_tensor("y")
n_z = reader.read_tensor("z")
self.assertEqual(n_x + 1, x)
self.assertEqual(n_y, y)
self.assertEqual(n_z + 1, z)
self.assertEqual(
StorageWeakRef(n_x.untyped_storage()),
StorageWeakRef(n_z.untyped_storage()),
)
n_x2 = reader.read_tensor("x2")
n_y2 = reader.read_tensor("y2")
n_z2 = reader.read_tensor("z2")
self.assertEqual(n_x2, x)
self.assertEqual(n_y2, y)
self.assertEqual(n_z2, z)
self.assertEqual(
StorageWeakRef(n_y2.untyped_storage()),
StorageWeakRef(n_y.untyped_storage()),
)
def test_scalar(self, device):
# Should not raise an error
hash_storage(torch.tensor(2, device=device).untyped_storage())
@torch._dynamo.config.patch(cache_size_limit=1)
def test_repeated_hash(self, device):
# Test that repeated hashing doesn't trigger a recompile in dynamo
# If it does, we will execute prims.xor_sum in eager which fails
for _ in range(4):
hash_storage(torch.tensor(2, device=device).untyped_storage())
@skipIfRocm
def test_load_tensor(self, device):
with tempfile.TemporaryDirectory() as loc:
writer = ContentStoreWriter(loc)
x = torch.randn(4, device=device)
def same_meta_as_x(t):
self.assertEqual(t.size(), x.size())
self.assertEqual(t.stride(), x.stride())
self.assertEqual(t.dtype, x.dtype)
self.assertEqual(t.device, x.device)
writer.write_tensor("x", x)
with load_tensor_reader(loc):
x2 = torch.ops.debugprims.load_tensor.default(
"x", (4,), (1,), dtype=torch.float32, device=device
)
self.assertEqual(x, x2)
x3 = torch.ops.debugprims.load_tensor.default(
"x", (4,), (1,), dtype=torch.float32, device=device
)
self.assertEqual(x, x3)
# Must not alias!
self.assertNotEqual(
StorageWeakRef(x.untyped_storage()),
StorageWeakRef(x2.untyped_storage()),
)
self.assertNotEqual(
StorageWeakRef(x2.untyped_storage()),
StorageWeakRef(x3.untyped_storage()),
)
# Check fake tensor mode works too
with FakeTensorMode():
x4 = torch.ops.debugprims.load_tensor.default(
"x", (4,), (1,), dtype=torch.float32, device=device
)
self.assertIsInstance(x4, FakeTensor)
same_meta_as_x(x4)
# Check fp64 works
x5 = torch.ops.debugprims.load_tensor.default(
"x", (4,), (1,), dtype=torch.float64, device=device
)
self.assertEqual(x5.float(), x)
self.assertEqual(x5.dtype, torch.float64)
x6 = torch.ops.debugprims.load_tensor.default(
"x", (4,), (1,), dtype=torch.float32, device=device
)
same_meta_as_x(x6)
# TestContentStore is a device-generic template (each test takes a `device`
# argument); expand it into per-device test classes in this module's globals.
instantiate_device_type_tests(TestContentStore, scope=globals())

if __name__ == "__main__":
    # Hand control to PyTorch's shared test entry point.
    run_tests()