# blob: 3a6445482cc17dd40adf26a4f39fdf05f4c2ecdf [file] [log] [blame]
# Owner(s): ["oncall: export"]
import torch
import torch.utils._pytree as pytree
from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch._library.fake_class_registry import FakeScriptObject
from torch.export import export
from torch.export._trace import _export
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skipIfTorchDynamo,
TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
class TestExportTorchbind(TestCase):
def setUp(self):
    """Load the torchbind test classes and register Python fakes for them.

    Fake classes are registered for ``_TorchScriptTesting::_Foo`` and
    ``_TorchScriptTesting::_TensorQueue``. Each fake method increments a
    counter stored on this test case so tests can assert how many times the
    fake was invoked. The torchbind custom ops used by the dispatcher tests
    are collected in ``self.torch_bind_ops``.
    """
    init_torchbind_implementations()

    # The fakes below are nested classes; they reach the test case through
    # this closure variable because their own ``self`` is the fake object.
    case = self
    self.tq_push_counter = 0
    self.tq_pop_counter = 0
    self.tq_size_counter = 0
    self.foo_add_tensor_counter = 0

    @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
    class FakeFoo:
        def __init__(self, x: int, y: int):
            self.x, self.y = x, y

        @classmethod
        def from_real(cls, foo):
            # The state unpacks as a pair whose first item is (x, y).
            state, _ = foo.__getstate__()
            x, y = state
            return cls(x, y)

        def add_tensor(self, z):
            case.foo_add_tensor_counter += 1
            total = self.x + self.y
            return total * z

    @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
    class FakeTensorQueue:
        def __init__(self, q):
            self.queue = q

        @classmethod
        def from_real(cls, real_tq):
            ctx = torch.library.get_ctx()
            fake_tensors = []
            for real_tensor in real_tq.get_raw_queue():
                fake_tensors.append(ctx.to_fake_tensor(real_tensor))
            return cls(fake_tensors)

        def push(self, x):
            case.tq_push_counter += 1
            self.queue.append(x)

        def pop(self):
            case.tq_pop_counter += 1
            # First in, first out: take from the front of the list.
            return self.queue.pop(0)

        def size(self):
            case.tq_size_counter += 1
            return len(self.queue)

    testing_ops = torch.ops._TorchScriptTesting
    self.torch_bind_ops = [
        testing_ops.takes_foo,
        testing_ops.takes_foo_python_meta,
        testing_ops.takes_foo_list_return,
        testing_ops.takes_foo_tuple_return,
        testing_ops.take_an_instance,
        testing_ops.take_an_instance_inferred,
        testing_ops.takes_foo_cia,
        testing_ops.queue_pop,
        testing_ops.queue_push,
        testing_ops.queue_size,
    ]
def tearDown(self):
    """Deregister the two fake classes that ``setUp`` registered."""
    deregister = torch._library.fake_class_registry.deregister_fake_class
    for qualname in (
        "_TorchScriptTesting::_Foo",
        "_TorchScriptTesting::_TensorQueue",
    ):
        deregister(qualname)
def _assertEqualSkipScriptObject(self, exp, actual):
    """Assert two pytrees have equal leaves, skipping paired script objects.

    Both inputs are flattened with ``pytree.tree_leaves``. The leaf counts
    must match, and leaves are compared position by position. A pair is not
    compared when both leaves are ``torch.ScriptObject`` instances.
    """
    expected_leaves = pytree.tree_leaves(exp)
    actual_leaves = pytree.tree_leaves(actual)
    self.assertEqual(len(expected_leaves), len(actual_leaves))
    for want, got in zip(expected_leaves, actual_leaves):
        both_script_objects = isinstance(
            want, torch.ScriptObject
        ) and isinstance(got, torch.ScriptObject)
        if not both_script_objects:
            self.assertEqual(want, got)
def _test_export_same_as_eager(
    self, f, args, kwargs=None, strict=True, pre_dispatch=False
):
    """Export ``f`` and check the result matches eager execution.

    The exported program's unlifted module is called with ``args``/``kwargs``
    and again with the keyword arguments in reversed order; both results must
    equal ``f(*args, **kwargs)``. The unlifted module is then exported a
    second time and that result is checked against the eager output as well.

    Args:
        f: callable/module to export.
        args: positional example inputs.
        kwargs: keyword example inputs; ``None`` is treated as ``{}``.
        strict: forwarded as ``strict=`` to the export call.
        pre_dispatch: when true, trace with ``_export(..., pre_dispatch=True)``;
            otherwise use ``export``.

    Returns:
        The exported program produced by the first export of ``f``.
    """
    kwargs = kwargs or {}

    # The parameter used to be misspelled ("strcit"), so the body silently
    # read ``strict`` from the enclosing scope instead of its own argument.
    def export_wrapper(f, args, kwargs, strict, pre_dispatch):
        with enable_torchbind_tracing():
            if pre_dispatch:
                exported_program = _export(
                    f, args, kwargs, strict=strict, pre_dispatch=True
                )
            else:
                exported_program = export(f, args, kwargs, strict=strict)
        return exported_program

    exported_program = export_wrapper(f, args, kwargs, strict, pre_dispatch)
    # Same keyword arguments in reversed order: the result must not depend
    # on the order in which keywords are supplied.
    reversed_kwargs = {key: kwargs[key] for key in reversed(kwargs)}
    unlifted = exported_program.module()
    exp = f(*args, **kwargs)
    self.assertEqual(unlifted(*args, **kwargs), exp)
    self.assertEqual(
        unlifted(*args, **reversed_kwargs),
        exp,
    )

    # check re-tracing
    retraced_ep = export_wrapper(unlifted, args, kwargs, strict, pre_dispatch)
    self.assertEqual(retraced_ep.module()(*args, **kwargs), exp)
    return exported_program
@parametrize("pre_dispatch", [True, False])
def test_none(self, pre_dispatch):
    """Export a module that receives ``None`` as its second input.

    Checks the export against eager execution, then pins the generated code
    of the unlifted module and of the exported graph module.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

        def forward(self, x, n):
            # ``n`` takes no part in the computation; the test passes None.
            return x + self.attr.add_tensor(x)

    ep = self._test_export_same_as_eager(
        MyModule(),
        (torch.ones(2, 3), None),
        strict=False,
        pre_dispatch=pre_dispatch,
    )
    # Code of the unlifted module: the script object is read from self.attr.
    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x, n):
x, n, = fx_pytree.tree_flatten_spec(([x, n], {}), self._in_spec)
attr = self.attr
call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
return pytree.tree_unflatten((add,), self._out_spec)""",
    )
    # Code of the exported graph: the script object is an explicit input.
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, obj_attr, x, n):
call_torchbind = torch.ops.higher_order.call_torchbind(obj_attr, 'add_tensor', x); obj_attr = None
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
return (add,)""",
    )
@parametrize("pre_dispatch", [True, False])
def test_attribute(self, pre_dispatch):
    """Export a module that calls a method on a script-object attribute.

    Checks the export against eager execution, then pins the generated code
    of the unlifted module and of the exported graph module.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

        def forward(self, x):
            return x + self.attr.add_tensor(x)

    ep = self._test_export_same_as_eager(
        MyModule(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
    )
    # Code of the unlifted module: the script object is read from self.attr.
    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
attr = self.attr
call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
return pytree.tree_unflatten((add,), self._out_spec)""",
    )
    # Code of the exported graph: the script object is an explicit input.
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, obj_attr, x):
call_torchbind = torch.ops.higher_order.call_torchbind(obj_attr, 'add_tensor', x); obj_attr = None
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
return (add,)""",
    )
@parametrize("pre_dispatch", [True, False])
def test_attribute_as_custom_op_argument(self, pre_dispatch):
    """Export a module that passes a script-object attribute to a custom op.

    Checks the export against eager execution, then pins the generated code
    of the unlifted module and of the exported graph module.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

        def forward(self, x):
            return x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x)

    ep = self._test_export_same_as_eager(
        MyModule(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
    )
    # Unlifted module: the custom op is called directly.
    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
attr = self.attr
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x); attr = None
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
return pytree.tree_unflatten((add,), self._out_spec)""",
    )
    # Exported graph: the op is wrapped in with_effects and a token is
    # taken as the first input and returned as the first output.
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, token, obj_attr, x):
with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, x); token = obj_attr = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None
return (getitem, add)""",  # noqa: B950
    )
@parametrize("pre_dispatch", [True, False])
@parametrize("fakify_script_obj", [True, False])
def test_input(self, pre_dispatch, fakify_script_obj):
    """Export a module that takes a script object as a forward input.

    With ``fakify_script_obj=False`` the fake class for ``_Foo`` is
    deregistered before exporting. With ``fakify_script_obj=True`` the test
    additionally asserts how many times the fake ``add_tensor`` was called.
    """
    cc = torch.classes._TorchScriptTesting._Foo(10, 20)
    if not fakify_script_obj:
        qual_name = cc._type().qualified_name()  # type: ignore[attr-defined]
        if torch._library.fake_class_registry.has_fake_class(qual_name):
            torch._library.fake_class_registry.deregister_fake_class(
                "_TorchScriptTesting::_Foo"
            )

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, cc):
            return x + cc.add_tensor(x)

    ep = self._test_export_same_as_eager(
        MyModule(), (torch.ones(2, 3), cc), strict=False, pre_dispatch=pre_dispatch
    )
    # Code of the unlifted module.
    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x, cc):
x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec)
call_torchbind = torch.ops.higher_order.call_torchbind(cc, 'add_tensor', x); cc = None
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
return pytree.tree_unflatten((add,), self._out_spec)""",
    )
    # Code of the exported graph module.
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, x, cc):
call_torchbind = torch.ops.higher_order.call_torchbind(cc, 'add_tensor', x); cc = None
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
return (add,)""",
    )
    # aot_export_function runs the program twice
    # in run_functionalized_fw_and_collect_metadata and create_aot_dispatcher_function
    # We also have a re-tracing test, which doubles the count.
    if fakify_script_obj:
        self.assertEqual(self.foo_add_tensor_counter, 4)
@parametrize("pre_dispatch", [True, False])
@parametrize("fakify_script_obj", [True, False])
def test_input_as_custom_op_argument(self, pre_dispatch, fakify_script_obj):
    """Export a module that forwards a script-object input to a custom op.

    The Python ``Meta`` kernel of ``takes_foo.default`` is removed first.
    With ``fakify_script_obj=True`` the export is then expected to raise
    "no python implementation is found". A Python ``Meta`` kernel is
    registered again before the export whose generated code is pinned.
    """
    cc = torch.classes._TorchScriptTesting._Foo(10, 20)
    if not fakify_script_obj:
        qual_name = cc._type().qualified_name()  # type: ignore[attr-defined]
        if torch._library.fake_class_registry.has_fake_class(qual_name):
            torch._library.fake_class_registry.deregister_fake_class(
                "_TorchScriptTesting::_Foo"
            )

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, cc):
            return x + torch.ops._TorchScriptTesting.takes_foo(cc, x)

    # Drop the Python Meta kernel and clear the dispatch cache so the
    # removal is visible to the next dispatch.
    del torch.ops._TorchScriptTesting.takes_foo.default.py_kernels[
        torch._C.DispatchKey.Meta
    ]
    torch.ops._TorchScriptTesting.takes_foo.default._dispatch_cache.clear()
    # Even though a C++ implementation for takes_foo.default is registered,
    # we still need the python implementation for takes_foo.default to trace with FakeFoo.
    if fakify_script_obj:
        with self.assertRaisesRegex(
            RuntimeError, "no python implementation is found"
        ):
            self._test_export_same_as_eager(
                MyModule(),
                (torch.ones(2, 3), cc),
                strict=False,
                pre_dispatch=pre_dispatch,
            )

    # Register a Python Meta kernel again; this runs for both parameter
    # values, so the kernel deleted above is always replaced.
    torch.ops._TorchScriptTesting.takes_foo.default.py_impl(
        torch._C.DispatchKey.Meta
    )(lambda cc, x: cc.add_tensor(x))
    ep = self._test_export_same_as_eager(
        MyModule(),
        (torch.ones(2, 3), cc),
        strict=False,
        pre_dispatch=pre_dispatch,
    )
    # Code of the unlifted module.
    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x, cc):
x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec)
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(cc, x); cc = None
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
return pytree.tree_unflatten((add,), self._out_spec)""",
    )
    # Code of the exported graph module (token threaded via with_effects).
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, token, x, cc):
with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, cc, x); token = cc = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None
return (getitem, add)""",  # noqa: B950
    )
@parametrize("pre_dispatch", [True, False])
def test_unlift_custom_obj(self, pre_dispatch):
    """Export a module that calls a custom op twice on the same attribute.

    The second call consumes the first call's output. Pins the generated
    code of the unlifted module and of the exported graph module.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

        def forward(self, x):
            a = torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
            b = torch.ops._TorchScriptTesting.takes_foo(self.attr, a)
            return x + b

    input = torch.ones(2, 3)
    ep = self._test_export_same_as_eager(
        MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
    )
    # Code of the unlifted module.
    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
attr = self.attr
takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
return pytree.tree_unflatten((add,), self._out_spec)""",  # noqa: B950
    )
    # Exported graph: the token returned by the first with_effects call is
    # the token passed to the second one.
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, token, obj_attr, x):
with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, x); token = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, getitem_1); getitem = obj_attr = getitem_1 = None
getitem_2 = with_effects_1[0]
getitem_3 = with_effects_1[1]; with_effects_1 = None
add = torch.ops.aten.add.Tensor(x, getitem_3); x = getitem_3 = None
return (getitem_2, add)""",  # noqa: B950
    )
@parametrize("pre_dispatch", [True, False])
def test_custom_obj_list_out(self, pre_dispatch):
    """Export a module using a custom op whose result is indexed as a list.

    Three elements of the result are summed and fed to a second custom op.
    Pins the generated code of the unlifted module and of the exported
    graph module.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

        def forward(self, x):
            a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x)
            y = a[0] + a[1] + a[2]
            b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
            return x + b

    input = torch.ones(2, 3)
    ep = self._test_export_same_as_eager(
        MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
    )
    # Code of the unlifted module.
    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
attr = self.attr
takes_foo_list_return_default = torch.ops._TorchScriptTesting.takes_foo_list_return.default(attr, x)
getitem_2 = takes_foo_list_return_default[0]
getitem_3 = takes_foo_list_return_default[1]
getitem_4 = takes_foo_list_return_default[2]; takes_foo_list_return_default = None
add = torch.ops.aten.add.Tensor(getitem_2, getitem_3); getitem_2 = getitem_3 = None
add_1 = torch.ops.aten.add.Tensor(add, getitem_4); add = getitem_4 = None
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, add_1); attr = add_1 = None
add_2 = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
return pytree.tree_unflatten((add_2,), self._out_spec)""",
    )
    # Exported graph: the list is the single item at with_effects[1] and is
    # indexed afterwards.
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, token, obj_attr, x):
with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_list_return.default, obj_attr, x); token = None
getitem = with_effects[0]
getitem_1 = with_effects[1]; with_effects = None
getitem_2 = getitem_1[0]
getitem_3 = getitem_1[1]
getitem_4 = getitem_1[2]; getitem_1 = None
add = torch.ops.aten.add.Tensor(getitem_2, getitem_3); getitem_2 = getitem_3 = None
add_1 = torch.ops.aten.add.Tensor(add, getitem_4); add = getitem_4 = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, add_1); getitem = obj_attr = add_1 = None
getitem_5 = with_effects_1[0]
getitem_6 = with_effects_1[1]; with_effects_1 = None
add_2 = torch.ops.aten.add.Tensor(x, getitem_6); x = getitem_6 = None
return (getitem_5, add_2)""",  # noqa: B950
    )
@parametrize("pre_dispatch", [True, False])
def test_custom_obj_tuple_out(self, pre_dispatch):
    """Export a module using a custom op whose result is indexed as a pair.

    The two elements are summed and fed to a second custom op. Pins the
    generated code of the unlifted module and of the exported graph module.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

        def forward(self, x):
            a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
            y = a[0] + a[1]
            b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
            return x + b

    input = torch.ones(2, 3)
    ep = self._test_export_same_as_eager(
        MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
    )
    # Code of the unlifted module.
    self.assertExpectedInline(
        ep.module().code.strip(),
        """\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
attr = self.attr
takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(attr, x)
getitem_1 = takes_foo_tuple_return_default[0]
getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None
add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, add); attr = add = None
add_1 = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
return pytree.tree_unflatten((add_1,), self._out_spec)""",
    )
    # Exported graph: the tuple elements appear directly at with_effects[1]
    # and with_effects[2], after the token at index 0.
    self.assertExpectedInline(
        ep.graph_module.code.strip(),
        """\
def forward(self, token, obj_attr, x):
with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, obj_attr, x); token = None
getitem = with_effects[0]
getitem_1 = with_effects[1]
getitem_2 = with_effects[2]; with_effects = None
add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, add); getitem = obj_attr = add = None
getitem_3 = with_effects_1[0]
getitem_4 = with_effects_1[1]; with_effects_1 = None
add_1 = torch.ops.aten.add.Tensor(x, getitem_4); x = getitem_4 = None
return (getitem_3, add_1)""",  # noqa: B950
    )
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
def test_make_fx_tensor_queue_methods(self, make_fx_tracing_mode):
    """Trace push/pop/size method calls on a tensor queue with ``make_fx``.

    Asserts that each fake method ran the expected number of times during
    tracing, that the real queue is still empty afterwards, and pins the
    traced code. Finally the traced graph is run and compared with eager.
    """
    test = self

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 2)
            # Toggled off before the eager run at the end of the test.
            self.check_tq_is_fake = True

        def forward(self, tq, x):
            if self.check_tq_is_fake:
                test.assertTrue(isinstance(tq, FakeScriptObject))
            tq.push(x.cos())
            tq.push(x.sin())
            x_cos = tq.pop() + tq.size()
            x_sin = tq.pop() - tq.size()
            return x_sin, x_cos, tq

    mod = Model()
    # Two separately constructed queues: ``tq`` is used for tracing and for
    # running the traced graph, ``tq1`` for the eager reference run.
    tq = torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(
            0,
        ).fill_(-1)
    )
    tq1 = torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(
            0,
        ).fill_(-1)
    )
    x = torch.ones(2, 3)
    gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
    # Counters are incremented by the fake queue registered in setUp.
    self.assertEqual(self.tq_push_counter, 2)
    self.assertEqual(self.tq_pop_counter, 2)
    self.assertEqual(self.tq_size_counter, 2)
    # The real queue reports size 0 after tracing.
    self.assertEqual(tq.size(), 0)
    self.assertExpectedInline(
        gm.code.strip("\n"),
        """\
def forward(self, arg0_1, arg1_1):
cos = torch.ops.aten.cos.default(arg1_1)
call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'push', cos); cos = None
sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'push', sin); sin = None
call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
add = torch.ops.aten.add.Tensor(call_torchbind_2, 1); call_torchbind_2 = None
call_torchbind_4 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
call_torchbind_5 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
sub = torch.ops.aten.sub.Tensor(call_torchbind_4, 0); call_torchbind_4 = None
return (sub, add, arg0_1)
""",
    )
    # Real queues are not FakeScriptObjects, so disable the check in forward.
    mod.check_tq_is_fake = False
    self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
@parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
def test_make_fx_tensor_queue_methods_fakify_internal_states(
    self, make_fx_tracing_mode
):
    """Trace pop/size calls on a queue that already holds two tensors.

    Asserts the fake-method call counts, that the real queue's size is
    unchanged by tracing, and pins the traced code. Running the traced
    graph and the eager module afterwards leaves both real queues empty.
    """
    test = self

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 2)
            # Toggled off before the eager run at the end of the test.
            self.check_tq_is_fake = True
            self.current_test = test

        def forward(self, tq, x):
            if self.check_tq_is_fake:
                self.current_test.assertTrue(isinstance(tq, FakeScriptObject))
            x_cos = tq.pop() + tq.size() + x
            x_sin = tq.pop() - tq.size() + x
            return x_sin, x_cos, tq

    mod = Model()
    tq = torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(
            0,
        ).fill_(-1)
    )
    tq1 = torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(
            0,
        ).fill_(-1)
    )
    # Pre-fill both real queues with the same two tensors.
    for _ in range(2):
        tq.push(torch.ones(2, 3))
        tq1.push(torch.ones(2, 3))
    x = torch.ones(2, 3)
    prev_size = tq.size()
    gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
    # Counters are incremented by the fake queue registered in setUp.
    self.assertEqual(self.tq_push_counter, 0)
    self.assertEqual(self.tq_pop_counter, 2)
    self.assertEqual(self.tq_size_counter, 2)
    # Tracing must leave the real queue's size as it was.
    self.assertEqual(tq.size(), prev_size)
    self.assertExpectedInline(
        gm.code.strip("\n"),
        """\
def forward(self, arg0_1, arg1_1):
call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
add = torch.ops.aten.add.Tensor(call_torchbind, 1); call_torchbind = None
add_1 = torch.ops.aten.add.Tensor(add, arg1_1); add = None
call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
sub = torch.ops.aten.sub.Tensor(call_torchbind_2, 0); call_torchbind_2 = None
add_2 = torch.ops.aten.add.Tensor(sub, arg1_1); sub = arg1_1 = None
return (add_2, add_1, arg0_1)
""",
    )
    # turn off tq type checking in eager execution
    mod.check_tq_is_fake = False
    self._assertEqualSkipScriptObject(gm(tq, x), mod(tq1, x))
    # Both runs popped the two pre-filled tensors from their real queue.
    self.assertEqual(tq.size(), 0)
    self.assertEqual(tq1.size(), 0)
def test_identifying_torchbind_ops(self):
    """Check ``_has_torchbind_op_overload`` on torchbind and aten ops.

    The flag must be truthy for every op collected in ``setUp`` and falsy
    for ``aten.add`` and ``aten.cos``.
    """
    for torchbind_op in self.torch_bind_ops:
        self.assertTrue(torchbind_op._has_torchbind_op_overload)

    plain_aten_ops = [torch.ops.aten.add, torch.ops.aten.cos]
    for aten_op in plain_aten_ops:
        self.assertFalse(aten_op._has_torchbind_op_overload)
def test_torchbind_op_register_fallthrough(self):
    """Register a fallthrough kernel for each torchbind op and verify it.

    For every op, a fallthrough kernel is registered for ``AutocastCPU``
    inside a scoped library, and the dispatcher is asked to confirm the
    kernel for that key is a fallthrough while the library is active.
    """
    dispatch_key = torch._C.DispatchKey.AutocastCPU
    dispatch_key_name = "AutocastCPU"
    for packet in self.torch_bind_ops:
        overload = packet.default
        namespace, _ = torch._library.utils.parse_namespace(
            packet._qualified_op_name
        )
        with torch.library._scoped_library(namespace, "FRAGMENT") as lib:
            lib.impl(
                overload.name(),
                torch.library.fallthrough_kernel,
                dispatch_key_name,
            )
            is_fallthrough = (
                torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
                    overload.name(), dispatch_key
                )
            )
            self.assertTrue(is_fallthrough)
def test_torchbind_op_fallthrough_keys_respects_lib_impl(self):
    """Check ``_fallthrough_keys()`` follows what is registered via ``lib.impl``.

    Only ops with neither a dispatcher kernel nor a Python kernel for
    ``AutogradCPU`` are exercised. With a regular kernel registered, the key
    must be absent from ``_fallthrough_keys()``; with a fallthrough kernel
    registered, it must be present. At least one op must be exercised.
    """
    dispatch_key = torch._C.DispatchKey.AutogradCPU
    dispatch_key_name = "AutogradCPU"
    exercised = 0
    for packet in self.torch_bind_ops:
        overload = packet.default
        namespace, _ = torch._library.utils.parse_namespace(
            packet._qualified_op_name
        )
        already_has_kernel = torch._C._dispatch_has_kernel_for_dispatch_key(
            overload.name(), dispatch_key
        )
        # Skip ops that already have a kernel for this key.
        if already_has_kernel or dispatch_key in overload.py_kernels:
            continue

        exercised += 1
        # A regular (non-fallthrough) kernel: key must not be reported.
        with torch.library._scoped_library(namespace, "FRAGMENT") as lib:
            lib.impl(
                overload.name(), lambda *args, **kwargs: args, dispatch_key_name
            )
            self.assertTrue(dispatch_key not in overload._fallthrough_keys())

        # An explicit fallthrough kernel: key must be reported.
        with torch.library._scoped_library(namespace, "FRAGMENT") as lib:
            lib.impl(
                overload.name(),
                torch.library.fallthrough_kernel,
                dispatch_key_name,
            )
            self.assertTrue(dispatch_key in overload._fallthrough_keys())
    self.assertTrue(exercised > 0)
def test_make_fx_schema_checking_script_object(self):
    """Tracing ``queue_push`` with a ``_Foo`` object must raise.

    Both modules pass the ``_Foo`` instance ``foo`` to ``queue_push``, one
    positionally and one by keyword. With fallthrough kernels registered for
    three dispatch keys, ``make_fx`` in fake mode is expected to raise a
    ``RuntimeError`` matching "is expected to be a FakeScriptObject".
    """

    class Model(torch.nn.Module):
        def forward(self, tq, x, foo):
            torch.ops._TorchScriptTesting.queue_push(foo, x.cos())
            return tq

    class ModelCallByKW(torch.nn.Module):
        def forward(self, tq, x, foo):
            torch.ops._TorchScriptTesting.queue_push(x=x.cos(), foo=foo)
            return tq

    mod = Model()
    modkw = ModelCallByKW()

    foo = torch.classes._TorchScriptTesting._Foo(10, 20)
    x = torch.ones(3, 3)
    tq = torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(
            0,
        ).fill_(-1)
    )

    ns = "_TorchScriptTesting"
    with torch.library._scoped_library(ns, "FRAGMENT") as lib:
        op = torch.ops._TorchScriptTesting.queue_push
        for key_name in ("AutogradCPU", "ADInplaceOrView", "PythonTLSSnapshot"):
            lib.impl(op.__name__, torch.library.fallthrough_kernel, key_name)

        # Positional call first, then the keyword call.
        for module in (mod, modkw):
            with self.assertRaisesRegex(
                RuntimeError, "is expected to be a FakeScriptObject"
            ):
                make_fx(module, tracing_mode="fake")(tq, x, foo)
@parametrize("fallthrough_via", ["lib_impl", "py_impl"])
def test_make_fx_tensor_queue_operators(self, fallthrough_via):
    """Trace the queue custom ops under autocast with ``make_fx``.

    A fallthrough kernel for ``AutocastCUDA`` is registered for the three
    queue ops, either through a scoped library (``lib_impl``) or through
    ``py_impl`` on each default overload (``py_impl``). The traced code is
    pinned and the traced graph is compared with eager execution.
    """

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, tq, x):
            with torch.autocast("cuda", dtype=torch.bfloat16):
                torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
                torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
                x_sin = torch.ops._TorchScriptTesting.queue_pop(
                    tq
                ) - torch.ops._TorchScriptTesting.queue_size(tq)
                x_cos = torch.ops._TorchScriptTesting.queue_pop(
                    tq
                ) + torch.ops._TorchScriptTesting.queue_size(tq)
                return x_sin, x_cos, tq

    mod = Model()
    tq1 = torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(
            0,
        ).fill_(-1)
    )
    tq2 = torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(
            0,
        ).fill_(-1)
    )
    x = torch.ones(2, 3)
    # Eager run on tq1 before any fallthrough kernel is registered.
    mod(tq1, x)
    ops = [
        torch.ops._TorchScriptTesting.queue_push,
        torch.ops._TorchScriptTesting.queue_pop,
        torch.ops._TorchScriptTesting.queue_size,
    ]
    if fallthrough_via == "lib_impl":
        ns = "_TorchScriptTesting"
        # Tracing happens inside the ``with`` so the scoped registrations
        # are still active.
        with torch.library._scoped_library(ns, "FRAGMENT") as lib:
            for op in ops:
                lib.impl(
                    op.__name__, torch.library.fallthrough_kernel, "AutocastCUDA"
                )
            gm = make_fx(mod, tracing_mode="fake")(tq1, x)
    else:
        for op in ops:
            op.default.py_impl(torch._C.DispatchKey.AutocastCUDA)(
                torch.library.fallthrough_kernel
            )
        gm = make_fx(mod, tracing_mode="fake")(tq1, x)
        # Undo the py_impl registrations by hand.
        # NOTE(review): this cleanup is skipped if make_fx raises above;
        # consider try/finally so the kernels do not stay registered.
        for op in ops:
            op.default._dispatch_cache.clear()
            del op.default.py_kernels[torch._C.DispatchKey.AutocastCUDA]
    self.assertExpectedInline(
        gm.code.strip(),
        """\
def forward(self, arg0_1, arg1_1):
cos = torch.ops.aten.cos.default(arg1_1)
queue_push = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, cos); cos = None
sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
queue_push_1 = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, sin); sin = None
queue_pop = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1)
queue_size = torch.ops._TorchScriptTesting.queue_size.default(arg0_1)
sub = torch.ops.aten.sub.Tensor(queue_pop, 1); queue_pop = None
queue_pop_1 = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1)
queue_size_1 = torch.ops._TorchScriptTesting.queue_size.default(arg0_1)
add = torch.ops.aten.add.Tensor(queue_pop_1, 0); queue_pop_1 = None
return (sub, add, arg0_1)""",
    )
    # Traced graph on tq1 versus eager on the separate queue tq2.
    self._assertEqualSkipScriptObject(gm(tq1, x), mod(tq2, x))
def test_aot_export_tensor_queue_operators(self):
    """Run ``aot_export_module`` on a module using the queue custom ops.

    The queue and the tensor are converted to fake inputs under a
    ``FakeTensorMode`` before export, and the generated graph code is pinned.
    """

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, tq, x):
            torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
            torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
            x_sin = torch.ops._TorchScriptTesting.queue_pop(
                tq
            ) - torch.ops._TorchScriptTesting.queue_size(tq)
            x_cos = torch.ops._TorchScriptTesting.queue_pop(
                tq
            ) + torch.ops._TorchScriptTesting.queue_size(tq)
            return x_sin, x_cos, tq

    mod = Model()
    tq1 = torch.classes._TorchScriptTesting._TensorQueue(
        torch.empty(
            0,
        ).fill_(-1)
    )
    x = torch.ones(2, 3)
    # Build fake counterparts of both inputs under one FakeTensorMode.
    fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
    fake_tq1 = torch._library.fake_class_registry.to_fake_obj(fake_mode, tq1)
    fake_x = fake_mode.from_tensor(x)
    # Only the first element of the returned value (the graph module) is used.
    gm = aot_export_module(mod, (fake_tq1, fake_x), trace_joint=False)[0]
    # inputs: token, tq, x
    # return: token, x_sin, x_cos, tq
    self.assertExpectedInline(
        gm.code.strip(),
        """\
def forward(self, arg0_1, arg1_1, arg2_1):
cos = torch.ops.aten.cos.default(arg2_1)
with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops._TorchScriptTesting.queue_push.default, arg1_1, cos); arg0_1 = cos = None
getitem = with_effects[0]; with_effects = None
sin = torch.ops.aten.sin.default(arg2_1); arg2_1 = None
with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.queue_push.default, arg1_1, sin); getitem = sin = None
getitem_2 = with_effects_1[0]; with_effects_1 = None
with_effects_2 = torch._higher_order_ops.effects.with_effects(getitem_2, torch.ops._TorchScriptTesting.queue_pop.default, arg1_1); getitem_2 = None
getitem_4 = with_effects_2[0]
getitem_5 = with_effects_2[1]; with_effects_2 = None
with_effects_3 = torch._higher_order_ops.effects.with_effects(getitem_4, torch.ops._TorchScriptTesting.queue_size.default, arg1_1); getitem_4 = None
getitem_6 = with_effects_3[0]; with_effects_3 = None
sub = torch.ops.aten.sub.Tensor(getitem_5, 1); getitem_5 = None
with_effects_4 = torch._higher_order_ops.effects.with_effects(getitem_6, torch.ops._TorchScriptTesting.queue_pop.default, arg1_1); getitem_6 = None
getitem_8 = with_effects_4[0]
getitem_9 = with_effects_4[1]; with_effects_4 = None
with_effects_5 = torch._higher_order_ops.effects.with_effects(getitem_8, torch.ops._TorchScriptTesting.queue_size.default, arg1_1); getitem_8 = None
getitem_10 = with_effects_5[0]; with_effects_5 = None
add = torch.ops.aten.add.Tensor(getitem_9, 0); getitem_9 = None
return (getitem_10, sub, add, arg1_1)""",  # noqa: B950
    )
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
class TestRegisterFakeClass(TestCase):
    """Checks on what ``torch._library.register_fake_class`` accepts."""

    def setUp(self):
        init_torchbind_implementations()

    def tearDown(self):
        # Drop everything a test may have registered.
        registry = torch._library.fake_class_registry.global_fake_class_registry
        registry.clear()

    def test_register_fake_class_no_torch_bind_class(self):
        """Registering under an unknown class name raises ``RuntimeError``."""
        register = torch._library.register_fake_class
        with self.assertRaisesRegex(RuntimeError, "Tried to instantiate class"):

            @register("_TorchScriptTesting::NOT_A_VALID_NAME")
            class Invalid:
                pass

    def test_register_fake_class_no_from_real(self):
        """A fake class without ``from_real`` is rejected."""
        register = torch._library.register_fake_class
        with self.assertRaisesRegex(RuntimeError, "define a classmethod from_real"):

            @register("_TorchScriptTesting::_Foo")
            class InvalidFakeFoo:
                def __init__(self):
                    pass

    def test_register_fake_class_from_real_not_classmethod(self):
        """A ``from_real`` that is a plain method is rejected."""
        register = torch._library.register_fake_class
        with self.assertRaisesRegex(RuntimeError, "is not a classmethod"):

            @register("_TorchScriptTesting::_Foo")
            class FakeFoo:
                def __init__(self, x, y):
                    self.x, self.y = x, y

                # Deliberately missing @classmethod.
                def from_real(self, foo_obj):
                    state = foo_obj.__getstate__()
                    x, y = state
                    return FakeFoo(x, y)

    def test_register_fake_class_valid(self):
        """A fake class with a classmethod ``from_real`` registers cleanly."""

        class FakeFoo:
            def __init__(self, x, y):
                self.x, self.y = x, y

            @classmethod
            def from_real(cls, foo_obj):
                state = foo_obj.__getstate__()
                x, y = state
                return cls(x, y)

        # Direct-call form: the class is passed as the second argument.
        register = torch._library.register_fake_class
        register("_TorchScriptTesting::_Foo", FakeFoo)
# Instantiate the @parametrize-decorated tests of TestExportTorchbind.
instantiate_parametrized_tests(TestExportTorchbind)

# Run the suite when this file is executed directly.
if __name__ == "__main__":
    run_tests()