Fix unconvertible_ops as per #89261 (#89299)
Fixes #89261
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89299
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py
index 51adaef..5d1cdc5 100644
--- a/test/onnx/test_utility_funs.py
+++ b/test/onnx/test_utility_funs.py
@@ -124,6 +124,18 @@
_, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=12)
self.assertEqual(unconvertible_ops, [])
+ def test_it_returns_empty_list_when_model_contains_supported_inplace_ops(self):
+ class SkipConnectionModule(torch.nn.Module):
+ def forward(self, x):
+ out = x
+ out += x
+ out = torch.nn.functional.relu(out, inplace=True)
+
+ module = SkipConnectionModule()
+ x = torch.randn(4, 4)
+ _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=13)
+ self.assertEqual(unconvertible_ops, [])
+
@parameterized.parameterized_class(
[
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 67dd719..36d7fdb 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -1333,7 +1333,9 @@
# eliminated in the conversion passes. Users may still see errors caused
# by prim ops even though they don't show up in the list.
continue
- if not registration.registry.is_registered_op(domain_op, opset_version):
+ if not registration.registry.is_registered_op(
+ domain_op.rstrip("_"), opset_version
+ ):
# We consider all registered ops supported, even though some of them are
# only partially supported, because there is not yet a good way to check
# if an op is fully supported.