# Owner(s): ["module: unknown"]
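
# Tests for FakeTensor: a tensor subclass that mimics the metadata of a real
# tensor (shape, dtype, device) without allocating storage, so that operator
# shape and device propagation can be checked without running real kernels.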
import itertools
import unittest

import torch
from torch._subclasses import FakeTensor
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.jit_utils import RUN_CUDA


class FakeTensorTest(TestCase):
    def test_basic(self):
        x = FakeTensor.from_tensor(torch.empty(2, 2, device="cpu"))
        y = FakeTensor.from_tensor(torch.empty(4, 2, 2, device="cpu"))
        # Broadcasting (2, 2) with (4, 2, 2) yields (4, 2, 2).
        y = x + y
        self.assertEqual(y.shape, (4, 2, 2))
        self.assertEqual(y.device, torch.device("cpu"))

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_shape_take_not_device(self):
        x = FakeTensor.from_tensor(torch.empty(1, device="cpu"))
        y = FakeTensor.from_tensor(torch.empty(8, 8, device="cuda"))
        # resize_as_ adopts y's shape but keeps x's device.
        out = x.resize_as_(y)
        self.assertEqual(out.shape, (8, 8))
        self.assertEqual(out.device.type, "cpu")

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_zero_dim(self):
        x = FakeTensor.from_tensor(torch.tensor(0.0))
        y = FakeTensor.from_tensor(torch.rand([4, 4], device="cuda"))
        # A zero-dim cpu tensor does not dictate the output device;
        # the result follows the cuda operand.
        out = x + y
        self.assertEqual(out.shape, (4, 4))
        self.assertEqual(out.device, y.device)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_throw(self):
        x = FakeTensor.from_tensor(torch.tensor(0.0))
        y = FakeTensor.from_tensor(torch.rand([4, 4], device="cuda"))
        z = FakeTensor.from_tensor(torch.rand([4, 4], device="cpu"))
        # Mixing non-zero-dim cuda and cpu tensors is a device mismatch
        # and must raise, just as it would with real tensors.
        self.assertRaises(Exception, lambda: torch.lerp(x, y, z))

    def test_dispatch_device(self):
        x = FakeTensor.from_tensor(torch.rand([4, 4]))
        self.assertEqual(x.device.type, "cpu")

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_type_as(self):
        x = FakeTensor.from_tensor(torch.rand([16, 1], device="cpu"))
        y = FakeTensor.from_tensor(torch.rand([4, 4], device="cuda"))
        out = x.type_as(y)
        self.assertEqual(out.device.type, "cuda")


def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
    # True if maybe_contained_type is a subtype of `type` or of any type
    # recursively contained within it.
    return maybe_contained_type.isSubtypeOf(type) or any(
        contains_type(e, maybe_contained_type) for e in type.containedTypes()
    )
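
# For example, aten::cat takes Tensor[]: List[Tensor] is not itself a subtype
# of Tensor, but its contained element type is, so contains_type returns True
# for that argument.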


class FakeTensorOperatorInvariants(TestCase):
    def test_non_kwarg_only_device(self):
        def get_op(schema):
            namespace, name = schema.name.split("::")
            overload = schema.overload_name if schema.overload_name else "default"
            assert namespace == "aten"
            return getattr(getattr(torch.ops.aten, name), overload)

        for schema in torch._C._jit_get_all_schemas():
            namespace = schema.name.split("::")[0]
            if namespace != "aten":
                continue

            # Skip ops that neither take nor return a Tensor.
            ten_type = torch._C.TensorType.get()
            if not any(
                contains_type(arg.type, ten_type)
                for arg in itertools.chain(schema.arguments, schema.returns)
            ):
                continue

            # Every aten op with a non-kwarg-only Device argument must be
            # registered in _device_not_kwarg_ops so FakeTensor knows to
            # handle its device argument specially.
            opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
            has_non_kwarg_device = any(
                not arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
                for arg in schema.arguments
            )
            if has_non_kwarg_device:
                self.assertIn(
                    get_op(schema),
                    torch._subclasses.fake_tensor._device_not_kwarg_ops,
                )


if __name__ == "__main__":
    run_tests()