[jit] move torchbind tests to separate file (#37473)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37473
Test Plan: Imported from OSS
Differential Revision: D21297541
Pulled By: suo
fbshipit-source-id: 65c48094b1f26fbbf251021957257ce04279922b
diff --git a/test/jit/test_torchbind.py b/test/jit/test_torchbind.py
new file mode 100644
index 0000000..b505531
--- /dev/null
+++ b/test/jit/test_torchbind.py
@@ -0,0 +1,241 @@
+import io
+import os
+import sys
+import torch
+from typing import Optional
+
+# Make the helper files in test/ importable
+pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(pytorch_test_dir)
+from torch.testing._internal.jit_utils import JitTestCase
+from torch.testing._internal.common_utils import skipIfRocm
+from torch.testing import FileCheck
+
+if __name__ == "__main__":
+ raise RuntimeError(
+ "This test file is not meant to be run directly, use:\n\n"
+ "\tpython test/test_jit.py TESTNAME\n\n"
+ "instead."
+ )
+
+class TestTorchbind(JitTestCase):
+ @skipIfRocm
+ def test_torchbind(self):
+ def test_equality(f, cmp_key):
+ obj1 = f()
+ obj2 = torch.jit.script(f)()
+ return (cmp_key(obj1), cmp_key(obj2))
+
+ def f():
+ val = torch.classes._TorchScriptTesting._Foo(5, 3)
+ val.increment(1)
+ return val
+ test_equality(f, lambda x: x)
+
+ with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
+ val = torch.classes._TorchScriptTesting._Foo(5, 3)
+ val.increment('foo')
+
+ def f():
+ ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
+ return ss.pop()
+ test_equality(f, lambda x: x)
+
+ def f():
+ ss1 = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
+ ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
+ ss1.push(ss2.pop())
+ return ss1.pop() + ss2.pop()
+ test_equality(f, lambda x: x)
+
+ @skipIfRocm
+ def test_torchbind_take_as_arg(self):
+ global StackString # see [local resolution in python]
+ StackString = torch.classes._TorchScriptTesting._StackString
+
+ def foo(stackstring):
+ # type: (StackString)
+ stackstring.push("lel")
+ return stackstring
+
+ script_input = torch.classes._TorchScriptTesting._StackString([])
+ scripted = torch.jit.script(foo)
+ script_output = scripted(script_input)
+ self.assertEqual(script_output.pop(), "lel")
+
+ @skipIfRocm
+ def test_torchbind_return_instance(self):
+ def foo():
+ ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
+ return ss
+
+ scripted = torch.jit.script(foo)
+ # Ensure we are creating the object and calling __init__
+ # rather than calling the __init__wrapper nonsense
+ fc = FileCheck().check('prim::CreateObject()')\
+ .check('prim::CallMethod[name="__init__"]')
+ fc.run(str(scripted.graph))
+ out = scripted()
+ self.assertEqual(out.pop(), "mom")
+ self.assertEqual(out.pop(), "hi")
+
+ @skipIfRocm
+ def test_torchbind_return_instance_from_method(self):
+ def foo():
+ ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
+ clone = ss.clone()
+ ss.pop()
+ return ss, clone
+
+ scripted = torch.jit.script(foo)
+ out = scripted()
+ self.assertEqual(out[0].pop(), "hi")
+ self.assertEqual(out[1].pop(), "mom")
+ self.assertEqual(out[1].pop(), "hi")
+
+ @skipIfRocm
+ def test_torchbind_take_instance_as_method_arg(self):
+ def foo():
+ ss = torch.classes._TorchScriptTesting._StackString(["mom"])
+ ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
+ ss.merge(ss2)
+ return ss
+
+ scripted = torch.jit.script(foo)
+ out = scripted()
+ self.assertEqual(out.pop(), "hi")
+ self.assertEqual(out.pop(), "mom")
+
+ @skipIfRocm
+ def test_torchbind_return_tuple(self):
+ def f():
+ val = torch.classes._TorchScriptTesting._StackString(["3", "5"])
+ return val.return_a_tuple()
+
+ scripted = torch.jit.script(f)
+ tup = scripted()
+ self.assertEqual(tup, (1337.0, 123))
+
+ @skipIfRocm
+ def test_torchbind_save_load(self):
+ def foo():
+ ss = torch.classes._TorchScriptTesting._StackString(["mom"])
+ ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
+ ss.merge(ss2)
+ return ss
+
+ scripted = torch.jit.script(foo)
+ self.getExportImportCopy(scripted)
+
+ @skipIfRocm
+ def test_torchbind_lambda_method(self):
+ def foo():
+ ss = torch.classes._TorchScriptTesting._StackString(["mom"])
+ return ss.top()
+
+ scripted = torch.jit.script(foo)
+ self.assertEqual(scripted(), "mom")
+
+ @skipIfRocm
+ def test_torchbind_class_attribute(self):
+ class FooBar1234(torch.nn.Module):
+ def __init__(self):
+ super(FooBar1234, self).__init__()
+ self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])
+
+ def forward(self):
+ return self.f.top()
+
+ inst = FooBar1234()
+ scripted = torch.jit.script(inst)
+ eic = self.getExportImportCopy(scripted)
+ assert eic() == "deserialized"
+ for expected in ["deserialized", "was", "i"]:
+ assert eic.f.pop() == expected
+
+ @skipIfRocm
+ def test_torchbind_getstate(self):
+ class FooBar4321(torch.nn.Module):
+ def __init__(self):
+ super(FooBar4321, self).__init__()
+ self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
+
+ def forward(self):
+ return self.f.top()
+
+ inst = FooBar4321()
+ scripted = torch.jit.script(inst)
+ eic = self.getExportImportCopy(scripted)
+ # NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to
+ # return {1, 3, 3, 7}. I tried to make this actually depend on the
+ # values at instantiation in the test with some transformation, but
+ # because it seems we serialize/deserialize multiple times, that
+ # transformation isn't as you would it expect it to be.
+ assert eic() == 7
+ for expected in [7, 3, 3, 1]:
+ assert eic.f.pop() == expected
+
+ @skipIfRocm
+ def test_torchbind_tracing(self):
+ class TryTracing(torch.nn.Module):
+ def __init__(self):
+ super(TryTracing, self).__init__()
+ self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
+
+ def forward(self):
+ return torch.ops._TorchScriptTesting.take_an_instance(self.f)
+
+ traced = torch.jit.trace(TryTracing(), ())
+ self.assertEqual(torch.zeros(4, 4), traced())
+
+ @skipIfRocm
+ def test_torchbind_tracing_nested(self):
+ class TryTracingNest(torch.nn.Module):
+ def __init__(self):
+ super(TryTracingNest, self).__init__()
+ self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
+
+ class TryTracing123(torch.nn.Module):
+ def __init__(self):
+ super(TryTracing123, self).__init__()
+ self.nest = TryTracingNest()
+
+ def forward(self):
+ return torch.ops._TorchScriptTesting.take_an_instance(self.nest.f)
+
+ traced = torch.jit.trace(TryTracing123(), ())
+ self.assertEqual(torch.zeros(4, 4), traced())
+
+ @skipIfRocm
+ def test_torchbind_pickle_serialization(self):
+ nt = torch.classes._TorchScriptTesting._PickleTester([3, 4])
+ b = io.BytesIO()
+ torch.save(nt, b)
+ b.seek(0)
+ nt_loaded = torch.load(b)
+ for exp in [7, 3, 3, 1]:
+ self.assertEqual(nt_loaded.pop(), exp)
+
+ @skipIfRocm
+ def test_torchbind_instantiate_missing_class(self):
+ with self.assertRaisesRegex(RuntimeError, 'Tried to instantiate class \'foo.IDontExist\', but it does not exist!'):
+ torch.classes.foo.IDontExist(3, 4, 5)
+
+ @skipIfRocm
+ def test_torchbind_optional_explicit_attr(self):
+ class TorchBindOptionalExplicitAttr(torch.nn.Module):
+ foo : Optional[torch.classes._TorchScriptTesting._StackString]
+
+ def __init__(self):
+ super().__init__()
+ self.foo = torch.classes._TorchScriptTesting._StackString(["test"])
+
+ def forward(self) -> str:
+ foo_obj = self.foo
+ if foo_obj is not None:
+ return foo_obj.pop()
+ else:
+ return '<None>'
+
+ mod = TorchBindOptionalExplicitAttr()
+ scripted = torch.jit.script(mod)
diff --git a/test/test_jit.py b/test/test_jit.py
index ee734e9..c23dad1 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -23,6 +23,7 @@
from jit.test_save_load import TestSaveLoad # noqa: F401
from jit.test_python_ir import TestPythonIr # noqa: F401
from jit.test_functional_blocks import TestFunctionalBlocks # noqa: F401
+from jit.test_torchbind import TestTorchbind # noqa: F401
# Torch
from torch import Tensor
@@ -4579,114 +4580,6 @@
self.assertEqual(7, w(3))
self.assertFalse("training" in w.state_dict())
- @skipIfRocm
- def test_torchbind(self):
- def test_equality(f, cmp_key):
- obj1 = f()
- obj2 = torch.jit.script(f)()
- return (cmp_key(obj1), cmp_key(obj2))
-
- def f():
- val = torch.classes._TorchScriptTesting._Foo(5, 3)
- val.increment(1)
- return val
- test_equality(f, lambda x: x)
-
- with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
- val = torch.classes._TorchScriptTesting._Foo(5, 3)
- val.increment('foo')
-
- def f():
- ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
- return ss.pop()
- test_equality(f, lambda x: x)
-
- def f():
- ss1 = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
- ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
- ss1.push(ss2.pop())
- return ss1.pop() + ss2.pop()
- test_equality(f, lambda x: x)
-
- @skipIfRocm
- def test_torchbind_take_as_arg(self):
- global StackString # see [local resolution in python]
- StackString = torch.classes._TorchScriptTesting._StackString
-
- def foo(stackstring):
- # type: (StackString)
- stackstring.push("lel")
- return stackstring
-
- script_input = torch.classes._TorchScriptTesting._StackString([])
- scripted = torch.jit.script(foo)
- script_output = scripted(script_input)
- self.assertEqual(script_output.pop(), "lel")
-
- @skipIfRocm
- def test_torchbind_return_instance(self):
- def foo():
- ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
- return ss
-
- scripted = torch.jit.script(foo)
- # Ensure we are creating the object and calling __init__
- # rather than calling the __init__wrapper nonsense
- fc = FileCheck().check('prim::CreateObject()')\
- .check('prim::CallMethod[name="__init__"]')
- fc.run(str(scripted.graph))
- out = scripted()
- self.assertEqual(out.pop(), "mom")
- self.assertEqual(out.pop(), "hi")
-
- @skipIfRocm
- def test_torchbind_return_instance_from_method(self):
- def foo():
- ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
- clone = ss.clone()
- ss.pop()
- return ss, clone
-
- scripted = torch.jit.script(foo)
- out = scripted()
- self.assertEqual(out[0].pop(), "hi")
- self.assertEqual(out[1].pop(), "mom")
- self.assertEqual(out[1].pop(), "hi")
-
- @skipIfRocm
- def test_torchbind_take_instance_as_method_arg(self):
- def foo():
- ss = torch.classes._TorchScriptTesting._StackString(["mom"])
- ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
- ss.merge(ss2)
- return ss
-
- scripted = torch.jit.script(foo)
- out = scripted()
- self.assertEqual(out.pop(), "hi")
- self.assertEqual(out.pop(), "mom")
-
- @skipIfRocm
- def test_torchbind_return_tuple(self):
- def f():
- val = torch.classes._TorchScriptTesting._StackString(["3", "5"])
- return val.return_a_tuple()
-
- scripted = torch.jit.script(f)
- tup = scripted()
- self.assertEqual(tup, (1337.0, 123))
-
- @skipIfRocm
- def test_torchbind_save_load(self):
- def foo():
- ss = torch.classes._TorchScriptTesting._StackString(["mom"])
- ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
- ss.merge(ss2)
- return ss
-
- scripted = torch.jit.script(foo)
- self.getExportImportCopy(scripted)
-
def test_class_as_attribute(self):
@torch.jit.script
class Foo321(object):
@@ -4706,124 +4599,6 @@
x = torch.rand(3, 4)
self.assertEqual(scripted(x), eic(x))
- @skipIfRocm
- def test_torchbind_lambda_method(self):
- def foo():
- ss = torch.classes._TorchScriptTesting._StackString(["mom"])
- return ss.top()
-
- scripted = torch.jit.script(foo)
- self.assertEqual(scripted(), "mom")
-
- @skipIfRocm
- def test_torchbind_class_attribute(self):
- class FooBar1234(torch.nn.Module):
- def __init__(self):
- super(FooBar1234, self).__init__()
- self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])
-
- def forward(self):
- return self.f.top()
-
- inst = FooBar1234()
- scripted = torch.jit.script(inst)
- eic = self.getExportImportCopy(scripted)
- assert eic() == "deserialized"
- for expected in ["deserialized", "was", "i"]:
- assert eic.f.pop() == expected
-
- @skipIfRocm
- def test_torchbind_getstate(self):
- class FooBar4321(torch.nn.Module):
- def __init__(self):
- super(FooBar4321, self).__init__()
- self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
-
- def forward(self):
- return self.f.top()
-
- inst = FooBar4321()
- scripted = torch.jit.script(inst)
- eic = self.getExportImportCopy(scripted)
- # NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to
- # return {1, 3, 3, 7}. I tried to make this actually depend on the
- # values at instantiation in the test with some transformation, but
- # because it seems we serialize/deserialize multiple times, that
- # transformation isn't as you would it expect it to be.
- assert eic() == 7
- for expected in [7, 3, 3, 1]:
- assert eic.f.pop() == expected
-
- @skipIfRocm
- def test_torchbind_tracing(self):
- class TryTracing(torch.nn.Module):
- def __init__(self):
- super(TryTracing, self).__init__()
- self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
-
- def forward(self):
- return torch.ops._TorchScriptTesting.take_an_instance(self.f)
-
- traced = torch.jit.trace(TryTracing(), ())
- self.assertEqual(torch.zeros(4, 4), traced())
-
- @skipIfRocm
- def test_torchbind_tracing_nested(self):
- class TryTracingNest(torch.nn.Module):
- def __init__(self):
- super(TryTracingNest, self).__init__()
- self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])
-
- class TryTracing123(torch.nn.Module):
- def __init__(self):
- super(TryTracing123, self).__init__()
- self.nest = TryTracingNest()
-
- def forward(self):
- return torch.ops._TorchScriptTesting.take_an_instance(self.nest.f)
-
- traced = torch.jit.trace(TryTracing123(), ())
- self.assertEqual(torch.zeros(4, 4), traced())
-
- @skipIfRocm
- def test_torchbind_pickle_serialization(self):
- nt = torch.classes._TorchScriptTesting._PickleTester([3, 4])
- b = io.BytesIO()
- torch.save(nt, b)
- b.seek(0)
- nt_loaded = torch.load(b)
- for exp in [7, 3, 3, 1]:
- self.assertEqual(nt_loaded.pop(), exp)
-
- @skipIfRocm
- def test_torchbind_instantiate_missing_class(self):
- with self.assertRaisesRegex(RuntimeError, 'Tried to instantiate class \'foo.IDontExist\', but it does not exist!'):
- torch.classes.foo.IDontExist(3, 4, 5)
-
- @skipIfRocm
- def test_torchbind_optional_explicit_attr(self):
- class TorchBindOptionalExplicitAttr(torch.nn.Module):
- foo : Optional[torch.classes._TorchScriptTesting._StackString]
-
- def __init__(self):
- super().__init__()
- self.foo = torch.classes._TorchScriptTesting._StackString(["test"])
-
- def forward(self) -> str:
- foo_obj = self.foo
- if foo_obj is not None:
- return foo_obj.pop()
- else:
- return '<None>'
-
- mod = TorchBindOptionalExplicitAttr()
- scripted = torch.jit.script(mod)
-
- @skipIfRocm
- def test_torchbind_str(self):
- foo = torch.classes._TorchScriptTesting._StackString(["foo", "bar", "baz"])
- self.assertEqual(str(foo), "[foo, bar, baz]")
-
def test_module_str(self):
class Foo(torch.nn.Module):
def forward(self, x):
@@ -4832,12 +4607,6 @@
f = torch.jit.script(Foo())
self.assertEqual('ScriptObject', str(f._c))
- @skipIfRocm
- def test_torchbind_magic_unimplemented(self):
- foo = torch.classes._TorchScriptTesting._StackString(["foo", "bar", "baz"])
- with self.assertRaises(NotImplementedError):
- foo[3]
-
def _test_lower_graph_impl(self, model, data):
model.qconfig = torch.quantization.default_qconfig
model = torch.quantization.prepare(model)