# Owner(s): ["oncall: export"]
import unittest
import torch
import torch.utils._pytree as pytree
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch._library.fake_class_registry import FakeScriptObject
from torch.export import export
from torch.export._trace import _export
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
find_library_location,
instantiate_parametrized_tests,
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
parametrize,
run_tests,
skipIfTorchDynamo,
TestCase,
)
from torch.testing._internal.torchbind_impls import register_fake_operators


def load_torchbind_test_lib():
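    """Load the C++ test library that registers the _TorchScriptTesting classes.

    The library location is platform-specific; on macOS the test is skipped
    because the load_library call used here is not portable.
    """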
if IS_SANDCASTLE or IS_FBCODE:
torch.ops.load_library("//caffe2/test/cpp/jit:test_custom_class_registrations")
elif IS_MACOS:
raise unittest.SkipTest("non-portable load_library call used in test")
else:
lib_file_path = find_library_location("libtorchbind_test.so")
if IS_WINDOWS:
lib_file_path = find_library_location("torchbind_test.dll")
torch.ops.load_library(str(lib_file_path))
register_fake_operators()


@skipIfTorchDynamo("torchbind not supported with dynamo yet")
class TestExportTorchbind(TestCase):
def setUp(self):
load_torchbind_test_lib()
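
        # Register a fake (pure-Python) implementation of the C++ _Foo class so
        # instances can be traced without running real C++ code; `from_real`
        # copies the real object's state into the fake one.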
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
class FakeFoo:
def __init__(self, x: int, y: int):
self.x = x
self.y = y
@classmethod
def from_real(cls, foo):
(x, y), _ = foo.__getstate__()
return cls(x, y)
def add_tensor(self, z):
return (self.x + self.y) * z
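
        # Keep a handle to the test instance so the fake _TensorQueue below can
        # count how often each method is hit during tracing.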
test = self
test.tq_push_counter = 0
test.tq_pop_counter = 0
test.tq_size_counter = 0
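
        # Fake _TensorQueue backed by a plain Python list. `from_real` converts
        # the tensors in the real queue to fake tensors via the library context.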
@torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
class FakeTensorQueue:
def __init__(self, q):
self.queue = q
@classmethod
def from_real(cls, real_tq):
ctx = torch.library.get_ctx()
fake_queue = [ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()]
return cls(fake_queue)
def push(self, x):
test.tq_push_counter += 1
self.queue.append(x)
def pop(self):
test.tq_pop_counter += 1
return self.queue.pop(0)
def size(self):
test.tq_size_counter += 1
return len(self.queue)

    def tearDown(self):
torch._library.fake_class_registry.deregister_fake_class(
"_TorchScriptTesting::_Foo"
)
torch._library.fake_class_registry.deregister_fake_class(
"_TorchScriptTesting::_TensorQueue"
)

    def _assertEqualSkipScriptObject(self, exp, actual):
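        """Compare two pytrees leaf by leaf, skipping pairs of ScriptObjects,
        which cannot be compared by value."""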
flat_exp = pytree.tree_leaves(exp)
flat_actual = pytree.tree_leaves(actual)
self.assertEqual(len(flat_exp), len(flat_actual))
for a, b in zip(flat_exp, flat_actual):
if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
continue
self.assertEqual(a, b)

    def _test_export_same_as_eager(
self, f, args, kwargs=None, strict=True, pre_dispatch=False
):
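        """Export `f` and check that the exported module matches eager execution.

        Also checks that keyword-argument order does not matter and that the
        unlifted module can be re-traced to an equivalent program.
        """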
kwargs = kwargs or {}

        def export_wrapper(f, args, kwargs, strict, pre_dispatch):
with enable_torchbind_tracing():
if pre_dispatch:
exported_program = _export(
f, args, kwargs, strict=strict, pre_dispatch=True
)
else:
exported_program = export(f, args, kwargs, strict=strict)
return exported_program

        exported_program = export_wrapper(f, args, kwargs, strict, pre_dispatch)
reversed_kwargs = {key: kwargs[key] for key in reversed(kwargs)}
unlifted = exported_program.module()
exp = f(*args, **kwargs)
self.assertEqual(unlifted(*args, **kwargs), exp)
self.assertEqual(
unlifted(*args, **reversed_kwargs),
exp,
)
# check re-tracing
retraced_ep = export_wrapper(unlifted, args, kwargs, strict, pre_dispatch)
self.assertEqual(retraced_ep.module()(*args, **kwargs), exp)
return exported_program
@parametrize("pre_dispatch", [True, False])
def test_none(self, pre_dispatch):
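        # The second forward input is None; export should handle a None
        # argument alongside a torchbind attribute.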
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
def forward(self, x, n):
return x + self.attr.add_tensor(x)
ep = self._test_export_same_as_eager(
MyModule(),
(torch.ones(2, 3), None),
strict=False,
pre_dispatch=pre_dispatch,
)
self.assertExpectedInline(
ep.module().code.strip(),
"""\
def forward(self, arg_0, arg_1):
arg0_1, arg1_1, = fx_pytree.tree_flatten_spec(([arg_0, arg_1], {}), self._in_spec)
attr_1 = self.attr
call_torchbind = torch.ops.higher_order.call_torchbind(attr_1, 'add_tensor', arg0_1); attr_1 = None
add = torch.ops.aten.add.Tensor(arg0_1, call_torchbind); arg0_1 = call_torchbind = None
return pytree.tree_unflatten((add,), self._out_spec)""",
)
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, attr, arg0_1, arg1_1):
call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', arg0_1); attr = None
add = torch.ops.aten.add.Tensor(arg0_1, call_torchbind); arg0_1 = call_torchbind = None
return (add,)""",
)

    @parametrize("pre_dispatch", [True, False])
def test_attribute(self, pre_dispatch):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
def forward(self, x):
return x + self.attr.add_tensor(x)
ep = self._test_export_same_as_eager(
MyModule(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
)
self.assertExpectedInline(
ep.module().code.strip(),
"""\
def forward(self, arg_0):
arg0_1, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
attr_1 = self.attr
call_torchbind = torch.ops.higher_order.call_torchbind(attr_1, 'add_tensor', arg0_1); attr_1 = None
add = torch.ops.aten.add.Tensor(arg0_1, call_torchbind); arg0_1 = call_torchbind = None
return pytree.tree_unflatten((add,), self._out_spec)""",
)
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, attr, arg0_1):
call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', arg0_1); attr = None
add = torch.ops.aten.add.Tensor(arg0_1, call_torchbind); arg0_1 = call_torchbind = None
return (add,)""",
)

    @parametrize("pre_dispatch", [True, False])
def test_attribute_as_custom_op_argument(self, pre_dispatch):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
def forward(self, x):
return x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
ep = self._test_export_same_as_eager(
MyModule(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
)
self.assertExpectedInline(
ep.module().code.strip(),
"""\
def forward(self, arg_0):
arg1_1, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
attr_1 = self.attr
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr_1, arg1_1); attr_1 = None
add = torch.ops.aten.add.Tensor(arg1_1, takes_foo_default); arg1_1 = takes_foo_default = None
return pytree.tree_unflatten((add,), self._out_spec)""",
)
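        # Before unlifting, the effectful custom op is wrapped in with_effects,
        # which threads an effect token through the graph.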
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, arg0_1, attr, arg1_1):
with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops._TorchScriptTesting.takes_foo.default, attr, arg1_1); arg0_1 = attr = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
add = torch.ops.aten.add.Tensor(arg1_1, getitem_1); arg1_1 = getitem_1 = None
return (getitem, add)""", # noqa: B950
)

    @parametrize("pre_dispatch", [True, False])
def test_input(self, pre_dispatch):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, cc):
return x + cc.add_tensor(x)
cc = torch.classes._TorchScriptTesting._Foo(10, 20)
ep = self._test_export_same_as_eager(
MyModule(), (torch.ones(2, 3), cc), strict=False, pre_dispatch=pre_dispatch
)
self.assertExpectedInline(
ep.module().code.strip(),
"""\
def forward(self, arg_0, arg_1):
arg0_1, arg1_1, = fx_pytree.tree_flatten_spec(([arg_0, arg_1], {}), self._in_spec)
call_torchbind = torch.ops.higher_order.call_torchbind(arg1_1, 'add_tensor', arg0_1); arg1_1 = None
add = torch.ops.aten.add.Tensor(arg0_1, call_torchbind); arg0_1 = call_torchbind = None
return pytree.tree_unflatten((add,), self._out_spec)""",
)
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, arg0_1, arg1_1):
call_torchbind = torch.ops.higher_order.call_torchbind(arg1_1, 'add_tensor', arg0_1); arg1_1 = None
add = torch.ops.aten.add.Tensor(arg0_1, call_torchbind); arg0_1 = call_torchbind = None
return (add,)""",
)

    @parametrize("pre_dispatch", [True, False])
def test_input_as_custom_op_argument(self, pre_dispatch):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, cc):
return x + torch.ops._TorchScriptTesting.takes_foo(cc, x)
cc = torch.classes._TorchScriptTesting._Foo(10, 20)
ep = self._test_export_same_as_eager(
MyModule(), (torch.ones(2, 3), cc), strict=False, pre_dispatch=pre_dispatch
)
self.assertExpectedInline(
ep.module().code.strip(),
"""\
def forward(self, arg_0, arg_1):
arg1_1, arg2_1, = fx_pytree.tree_flatten_spec(([arg_0, arg_1], {}), self._in_spec)
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(arg2_1, arg1_1); arg2_1 = None
add = torch.ops.aten.add.Tensor(arg1_1, takes_foo_default); arg1_1 = takes_foo_default = None
return pytree.tree_unflatten((add,), self._out_spec)""",
)
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1):
with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops._TorchScriptTesting.takes_foo.default, arg2_1, arg1_1); arg0_1 = arg2_1 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
add = torch.ops.aten.add.Tensor(arg1_1, getitem_1); arg1_1 = getitem_1 = None
return (getitem, add)""", # noqa: B950
)

    @parametrize("pre_dispatch", [True, False])
def test_unlift_custom_obj(self, pre_dispatch):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
def forward(self, x):
a = torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
b = torch.ops._TorchScriptTesting.takes_foo(self.attr, a)
return x + b
input = torch.ones(2, 3)
ep = self._test_export_same_as_eager(
MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
)
self.assertExpectedInline(
ep.module().code.strip(),
"""\
def forward(self, arg_0):
arg1_1, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
attr_1 = self.attr
takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr_1, arg1_1)
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr_1, takes_foo_default_1); attr_1 = takes_foo_default_1 = None
add = torch.ops.aten.add.Tensor(arg1_1, takes_foo_default); arg1_1 = takes_foo_default = None
return pytree.tree_unflatten((add,), self._out_spec)""", # noqa: B950
)
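        # The token produced by the first with_effects call feeds the second
        # one, serializing the two effectful custom-op calls.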
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, arg0_1, attr, arg1_1):
with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops._TorchScriptTesting.takes_foo.default, attr, arg1_1); arg0_1 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, attr, getitem_1); getitem = attr = getitem_1 = None
getitem_2 = with_effects_1[0]
getitem_3 = with_effects_1[1]; with_effects_1 = None
add = torch.ops.aten.add.Tensor(arg1_1, getitem_3); arg1_1 = getitem_3 = None
return (getitem_2, add)""", # noqa: B950
)

    @parametrize("pre_dispatch", [True, False])
def test_custom_obj_list_out(self, pre_dispatch):
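        # takes_foo_list_return returns a list of tensors, which is unpacked
        # through getitem nodes in the exported graph.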
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
def forward(self, x):
a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x)
y = a[0] + a[1] + a[2]
b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
return x + b
input = torch.ones(2, 3)
ep = self._test_export_same_as_eager(
MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
)
self.assertExpectedInline(
ep.module().code.strip(),
"""\
def forward(self, arg_0):
arg1_1, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
attr_1 = self.attr
takes_foo_list_return_default = torch.ops._TorchScriptTesting.takes_foo_list_return.default(attr_1, arg1_1)
getitem_2 = takes_foo_list_return_default[0]
getitem_3 = takes_foo_list_return_default[1]
getitem_4 = takes_foo_list_return_default[2]; takes_foo_list_return_default = None
add = torch.ops.aten.add.Tensor(getitem_2, getitem_3); getitem_2 = getitem_3 = None
add_1 = torch.ops.aten.add.Tensor(add, getitem_4); add = getitem_4 = None
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr_1, add_1); attr_1 = add_1 = None
add_2 = torch.ops.aten.add.Tensor(arg1_1, takes_foo_default); arg1_1 = takes_foo_default = None
return pytree.tree_unflatten((add_2,), self._out_spec)""",
)
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, arg0_1, attr, arg1_1):
with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops._TorchScriptTesting.takes_foo_list_return.default, attr, arg1_1); arg0_1 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
getitem_2 = getitem_1[0]
getitem_3 = getitem_1[1]
getitem_4 = getitem_1[2]; getitem_1 = None
add = torch.ops.aten.add.Tensor(getitem_2, getitem_3); getitem_2 = getitem_3 = None
add_1 = torch.ops.aten.add.Tensor(add, getitem_4); add = getitem_4 = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, attr, add_1); getitem = attr = add_1 = None
getitem_5 = with_effects_1[0]
getitem_6 = with_effects_1[1]; with_effects_1 = None
add_2 = torch.ops.aten.add.Tensor(arg1_1, getitem_6); arg1_1 = getitem_6 = None
return (getitem_5, add_2)""", # noqa: B950
)

    @parametrize("pre_dispatch", [True, False])
def test_custom_obj_tuple_out(self, pre_dispatch):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
def forward(self, x):
a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
y = a[0] + a[1]
b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
return x + b
input = torch.ones(2, 3)
ep = self._test_export_same_as_eager(
MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
)
self.assertExpectedInline(
ep.module().code.strip(),
"""\
def forward(self, arg_0):
arg1_1, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
attr_1 = self.attr
takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(attr_1, arg1_1)
getitem_1 = takes_foo_tuple_return_default[0]
getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None
add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr_1, add); attr_1 = add = None
add_1 = torch.ops.aten.add.Tensor(arg1_1, takes_foo_default); arg1_1 = takes_foo_default = None
return pytree.tree_unflatten((add_1,), self._out_spec)""",
)
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, arg0_1, attr, arg1_1):
with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, attr, arg1_1); arg0_1 = None
getitem = with_effects[0]
getitem_1 = with_effects[1]
getitem_2 = with_effects[2]; with_effects = None
add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, attr, add); getitem = attr = add = None
getitem_3 = with_effects_1[0]
getitem_4 = with_effects_1[1]; with_effects_1 = None
add_1 = torch.ops.aten.add.Tensor(arg1_1, getitem_4); arg1_1 = getitem_4 = None
return (getitem_3, add_1)""", # noqa: B950
)
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
def test_make_fx_tensor_queue_methods(self, make_fx_tracing_mode):
test = self
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 2)
self.check_tq_is_fake = True
def forward(self, tq, x):
if self.check_tq_is_fake:
test.assertTrue(isinstance(tq, FakeScriptObject))
tq.push(x.cos())
tq.push(x.sin())
x_cos = tq.pop() + tq.size()
x_sin = tq.pop() - tq.size()
return x_sin, x_cos, tq
mod = Model()
tq = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
tq1 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
x = torch.ones(2, 3)
gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
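        # Tracing should hit the fake queue's methods, leaving the real queue
        # (created empty) untouched.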
self.assertEqual(self.tq_push_counter, 2)
self.assertEqual(self.tq_pop_counter, 2)
self.assertEqual(self.tq_size_counter, 2)
self.assertEqual(tq.size(), 0)
self.assertExpectedInline(
gm.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1):
cos = torch.ops.aten.cos.default(arg1_1)
call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'push', cos); cos = None
sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'push', sin); sin = None
call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
add = torch.ops.aten.add.Tensor(call_torchbind_2, 1); call_torchbind_2 = None
call_torchbind_4 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
call_torchbind_5 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
sub = torch.ops.aten.sub.Tensor(call_torchbind_4, 0); call_torchbind_4 = None
return (sub, add, arg0_1)
""",
)
mod.check_tq_is_fake = False
self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
def test_make_fx_tensor_queue_methods_fakify_internal_states(
self, make_fx_tracing_mode
):
test = self
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 2)
self.check_tq_is_fake = True
self.current_test = test
def forward(self, tq, x):
if self.check_tq_is_fake:
self.current_test.assertTrue(isinstance(tq, FakeScriptObject))
x_cos = tq.pop() + tq.size() + x
x_sin = tq.pop() - tq.size() + x
return x_sin, x_cos, tq
mod = Model()
tq = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
tq1 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
for _ in range(2):
tq.push(torch.ones(2, 3))
tq1.push(torch.ones(2, 3))
x = torch.ones(2, 3)
prev_size = tq.size()
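        # from_real should fakify the tensors already in the queue, so tracing
        # pops from the fake copy while the real queue keeps its size.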
gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
self.assertEqual(self.tq_push_counter, 0)
self.assertEqual(self.tq_pop_counter, 2)
self.assertEqual(self.tq_size_counter, 2)
self.assertEqual(tq.size(), prev_size)
self.assertExpectedInline(
gm.code.strip("\n"),
"""\
def forward(self, arg0_1, arg1_1):
call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
add = torch.ops.aten.add.Tensor(call_torchbind, 1); call_torchbind = None
add_1 = torch.ops.aten.add.Tensor(add, arg1_1); add = None
call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
sub = torch.ops.aten.sub.Tensor(call_torchbind_2, 0); call_torchbind_2 = None
add_2 = torch.ops.aten.add.Tensor(sub, arg1_1); sub = arg1_1 = None
return (add_2, add_1, arg0_1)
""",
)
# turn off tq type checking in eager execution
mod.check_tq_is_fake = False
self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
self.assertEqual(tq.size(), 0)
self.assertEqual(tq1.size(), 0)


@skipIfTorchDynamo("torchbind not supported with dynamo yet")
class TestRegisterFakeClass(TestCase):
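    """Tests the validation errors and warnings raised by register_fake_class."""
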
def setUp(self):
load_torchbind_test_lib()

    def tearDown(self):
torch._library.fake_class_registry.global_fake_class_registry.clear()

    def test_register_fake_class_no_torch_bind_class(self):
with self.assertRaisesRegex(RuntimeError, "Tried to instantiate class"):
@torch._library.register_fake_class("_TorchScriptTesting::NOT_A_VALID_NAME")
class Invalid:
pass

    def test_register_fake_class_no_from_real(self):
with self.assertRaisesRegex(RuntimeError, "define a classmethod from_real"):
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
class InvalidFakeFoo:
def __init__(self):
pass

    def test_register_fake_class_from_real_not_classmethod(self):
with self.assertRaisesRegex(RuntimeError, "is not a classmethod"):
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
class FakeFoo:
def __init__(self, x, y):
self.x = x
self.y = y
def from_real(self, foo_obj):
x, y = foo_obj.__getstate__()
return FakeFoo(x, y)

    def test_register_fake_class_valid(self):
class FakeFoo:
def __init__(self, x, y):
self.x = x
self.y = y
@classmethod
def from_real(cls, foo_obj):
x, y = foo_obj.__getstate__()
return cls(x, y)
torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo)

    def test_register_fake_class_duplicate_registration(self):
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
class FakeFoo:
def __init__(self, x, y):
self.x = x
self.y = y
@classmethod
def from_real(cls, foo_obj):
x, y = foo_obj.__getstate__()
return cls(x, y)
with self.assertWarnsRegex(UserWarning, "already registered"):
torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo)


instantiate_parametrized_tests(TestExportTorchbind)

if __name__ == "__main__":
run_tests()