[reland] Cleanup custom op library after each custom_op test (#101450)
Reland of #100980. The original PR was reverted due to internal test
flakiness.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101450
Approved by: https://github.com/soulitzer
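
The change below adds a shared test namespace and a `tearDown` hook to `TestCustomOp` so that any custom op registered during a test is destroyed afterwards, even when the test fails partway through. A minimal sketch of that pattern, mirroring the hook in the diff (note that `torch._custom_op.global_registry` and `CustomOp._destroy()` are private PyTorch internals that this test suite relies on):

```python
import torch._custom_op
from torch.testing._internal.common_utils import TestCase


class TestCustomOp(TestCase):
    # Every op in these tests is registered under this private namespace,
    # e.g. '_test_custom_op::foo', so cleanup can identify them reliably.
    test_ns = '_test_custom_op'

    def tearDown(self):
        # Snapshot the keys so the registry can be mutated safely
        # while we tear ops down.
        keys = list(torch._custom_op.global_registry.keys())
        for key in keys:
            # Only destroy ops created by this test class.
            if not key.startswith(f'{TestCustomOp.test_ns}::'):
                continue
            torch._custom_op.global_registry[key]._destroy()
```

Because each test now builds its qualified name from `TestCustomOp.test_ns` instead of the hard-coded `_torch_testing` namespace, the hook only cleans up ops created by this class, and some per-test `foo._destroy()` calls become redundant (a few are removed in the diff).
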
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index 58dbdbf..ad2c875 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -372,22 +372,32 @@
class TestCustomOp(TestCase):
+    test_ns = '_test_custom_op'
+
+    def tearDown(self):
+        import torch._custom_op
+        keys = list(torch._custom_op.global_registry.keys())
+        for key in keys:
+            if not key.startswith(f'{TestCustomOp.test_ns}::'):
+                continue
+            torch._custom_op.global_registry[key]._destroy()
+
def test_invalid_schemas(self):
# function schema validation goes through torchgen, so this is just a
# basic test.
with self.assertRaisesRegex(AssertionError, 'Invalid function schema: foo'):
- @custom_op('_torch_testing::foo', "(")
+ @custom_op(f'{TestCustomOp.test_ns}::foo', "(")
def foo(x):
...
def test_name_must_match(self):
with self.assertRaisesRegex(ValueError, 'to have name'):
- @custom_op('_torch_testing::foo', "(Tensor x) -> Tensor")
+ @custom_op(f'{TestCustomOp.test_ns}::foo', "(Tensor x) -> Tensor")
def bar(x):
...
with self.assertRaisesRegex(ValueError, 'to have name'):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def baz(x: Tensor) -> Tensor:
...
@@ -396,107 +406,105 @@
...
with self.assertRaisesRegex(ValueError, 'does not support non-functional'):
- custom_op('_torch_testing::foo', '(Tensor(a!) x) -> Tensor(a)')(foo)
+ custom_op(f'{TestCustomOp.test_ns}::foo', '(Tensor(a!) x) -> Tensor(a)')(foo)
with self.assertRaisesRegex(ValueError, 'does not support view functions'):
- custom_op('_torch_testing::foo', '(Tensor(a) x) -> Tensor(a)')(foo)
+ custom_op(f'{TestCustomOp.test_ns}::foo', '(Tensor(a) x) -> Tensor(a)')(foo)
with self.assertRaisesRegex(ValueError, 'no outputs'):
- custom_op('_torch_testing::foo', '(Tensor x) -> ()')(foo)
+ custom_op(f'{TestCustomOp.test_ns}::foo', '(Tensor x) -> ()')(foo)
with self.assertRaisesRegex(ValueError, 'self'):
- custom_op('_torch_testing::foo', '(Tensor self) -> ()')(foo)
+ custom_op(f'{TestCustomOp.test_ns}::foo', '(Tensor self) -> ()')(foo)
def test_schema_matches_signature(self):
with self.assertRaisesRegex(ValueError, 'signature to match'):
- @custom_op('_torch_testing::blah', '(Tensor y) -> Tensor')
+ @custom_op(f'{TestCustomOp.test_ns}::blah', '(Tensor y) -> Tensor')
def blah(x):
pass
with self.assertRaisesRegex(ValueError, 'signature to match'):
- @custom_op('_torch_testing::blah2', '(Tensor x, *, Tensor y) -> Tensor')
+ @custom_op(f'{TestCustomOp.test_ns}::blah2', '(Tensor x, *, Tensor y) -> Tensor')
def blah2(x, y):
pass
with self.assertRaisesRegex(ValueError, 'signature to match'):
- @custom_op('_torch_testing::blah3', '(Tensor x, *, Tensor w, Tensor z) -> Tensor')
+ @custom_op(f'{TestCustomOp.test_ns}::blah3', '(Tensor x, *, Tensor w, Tensor z) -> Tensor')
def blah3(x, *, y, z):
pass
with self.assertRaisesRegex(ValueError, 'signature to match'):
- @custom_op('_torch_testing::blah4', '(Tensor x, *, Tensor z, Tensor y) -> Tensor')
+ @custom_op(f'{TestCustomOp.test_ns}::blah4', '(Tensor x, *, Tensor z, Tensor y) -> Tensor')
def blah4(x, *, y, z):
pass
with self.assertRaisesRegex(ValueError, 'not supported'):
- @custom_op('_torch_testing::blah5', '(Tensor x) -> Tensor')
+ @custom_op(f'{TestCustomOp.test_ns}::blah5', '(Tensor x) -> Tensor')
def blah5(*args):
pass
with self.assertRaisesRegex(ValueError, 'not supported'):
- @custom_op('_torch_testing::blah6', '(*, Tensor z, Tensor y) -> Tensor')
+ @custom_op(f'{TestCustomOp.test_ns}::blah6', '(*, Tensor z, Tensor y) -> Tensor')
def blah6(**kwargs):
pass
with self.assertRaisesRegex(ValueError, 'default arguments'):
- @custom_op('_torch_testing::blah7', '(Tensor x, *, Tensor y) -> Tensor')
+ @custom_op(f'{TestCustomOp.test_ns}::blah7', '(Tensor x, *, Tensor y) -> Tensor')
def blah7(x=1, *, y):
pass
with self.assertRaisesRegex(ValueError, 'default arguments'):
- @custom_op('_torch_testing::blah8', '(Tensor x, *, Tensor y) -> Tensor')
+ @custom_op(f'{TestCustomOp.test_ns}::blah8', '(Tensor x, *, Tensor y) -> Tensor')
def blah8(x, *, y=1):
pass
# kwonly-arg works
- @custom_op('_torch_testing::blah9', '(Tensor x, *, Tensor y) -> Tensor')
+ @custom_op(f'{TestCustomOp.test_ns}::blah9', '(Tensor x, *, Tensor y) -> Tensor')
def blah9(x, *, y):
pass
- blah9._destroy()
-
def test_unsupported_annotation_categories(self):
with self.assertRaisesRegex(ValueError, 'varargs'):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(*args):
...
del foo
with self.assertRaisesRegex(ValueError, 'varkwargs'):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(**kwargs):
...
del foo
with self.assertRaisesRegex(ValueError, 'must have a type annotation'):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x):
...
del foo
with self.assertRaisesRegex(ValueError, 'default value'):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: Optional[Tensor] = None):
...
del foo
with self.assertRaisesRegex(ValueError, 'default value'):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: Optional[Tensor] = None):
...
del foo
with self.assertRaisesRegex(ValueError, 'either Tensor or a Tuple'):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: Tensor) -> int:
...
del foo
with self.assertRaisesRegex(ValueError, 'either Tensor or a Tuple'):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: Tensor) -> Tuple[Tensor, int]:
...
del foo
with self.assertRaisesRegex(ValueError, 'either Tensor or a Tuple'):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: Tensor) -> Tuple[Tensor, ...]:
...
del foo
@@ -535,7 +543,7 @@
raise AssertionError(f"unsupported param type {typ}")
for typ in torch._custom_op.SUPPORTED_PARAM_TYPES:
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: Tensor, y: typ) -> Tensor:
...
@@ -560,14 +568,14 @@
def test_unsupported_param_types(self):
# Not comprehensive (it doesn't need to be), just a check that our mechanism works
with self.assertRaisesRegex(ValueError, 'unsupported type'):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: Tensor, y: Tuple[Optional[int], ...]) -> Tensor:
...
del foo
with self.assertRaisesRegex(ValueError, 'unsupported type'):
# int[N] in Dispatcher is a bit wild, so we don't try to support it.
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
...
del foo
@@ -575,13 +583,13 @@
with self.assertRaisesRegex(ValueError, 'unsupported type'):
# We could theoretically support this, but the syntax for supporting
# int[] is Tuple[int, ...]
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: Tensor, y: List[int]) -> Tensor:
...
del foo
with self.assertRaisesRegex(ValueError, 'unsupported type'):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: Tensor, y: Callable) -> Tensor:
...
del foo
@@ -629,10 +637,10 @@
...
for schema in schemas:
- op = custom_op('_torch_testing::foo', schema)(foo)
+ op = custom_op(f'{TestCustomOp.test_ns}::foo', schema)(foo)
op._destroy()
for schema in other_schemas:
- op = custom_op('_torch_testing::bar', schema)(bar)
+ op = custom_op(f'{TestCustomOp.test_ns}::bar', schema)(bar)
op._destroy()
def test_reserved_ns(self):
@@ -653,7 +661,7 @@
CustomOp(None, None, None, None)
def test_lifetime(self):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...
@@ -665,7 +673,7 @@
# We can't define an op multiple times,
with self.assertRaisesRegex(RuntimeError, 'multiple times'):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...
@@ -673,14 +681,14 @@
foo._destroy()
# Smoke test
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...
foo._destroy()
def test_autograd_notimplemented(self):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...
@@ -689,7 +697,7 @@
foo(x)
foo._destroy()
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: Tuple[torch.Tensor, ...]) -> torch.Tensor:
...
@@ -699,7 +707,7 @@
foo([y, x])
foo._destroy()
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
...
@@ -710,7 +718,7 @@
foo._destroy()
def test_impl_cpu(self):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...
@@ -721,10 +729,9 @@
x = torch.randn(3)
result = foo(x)
self.assertEqual(result, foo_cpu(x))
- foo._destroy()
def test_impl_invalid_devices(self):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...
@@ -747,7 +754,7 @@
@unittest.skipIf(not TEST_CUDA, "requires CUDA")
def test_impl_separate(self):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...
@@ -770,7 +777,7 @@
@unittest.skipIf(not TEST_CUDA, "requires CUDA")
def test_impl_multiple(self):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...
@@ -788,7 +795,7 @@
foo._destroy()
def test_impl_meta(self):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
...
@@ -804,7 +811,7 @@
foo._destroy()
def test_duplicate_impl(self):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
...
@@ -826,7 +833,7 @@
foo._destroy()
def test_new_data_dependent_symint(self):
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...
@@ -855,7 +862,7 @@
def test_basic_make_fx(self):
# More serious tests are in our CustomOp opinfo db,
# this one is just a sanity check.
- @custom_op('_torch_testing::foo')
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
...
@@ -865,7 +872,7 @@
x = torch.randn(3)
gm = make_fx(foo, tracing_mode='symbolic')(x)
- self.assertTrue('_torch_testing.foo' in gm.code)
+ self.assertTrue(f'{TestCustomOp.test_ns}.foo' in gm.code)
foo._destroy()
def test_abstract_registration_location(self):