Back out "Support regex-style matching for Any and Oneof (#82853)" (#83922)
Reviewed By: hl475
Differential Revision: D38945806
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83922
Approved by: https://github.com/hl475
diff --git a/test/test_fx_passes.py b/test/test_fx_passes.py
index 2ff1959..af02a5a 100644
--- a/test/test_fx_passes.py
+++ b/test/test_fx_passes.py
@@ -572,6 +572,7 @@
TestCase(False, True, 0),
]
+
class MultipleOutputsHorizontalPattern:
@staticmethod
def forward(x):
@@ -598,61 +599,6 @@
TestCase(True, True, 0)
]
-class PatternWithPseudoAny:
- @staticmethod
- def forward(x):
- x = x.relu()
- x = x.sigmoid()
-
- y = x.relu()
- y = y + 1
-
- z = y.relu()
- z = z.relu()
-
- return z
-
- @staticmethod
- def pattern(a):
- y = a.relu()
- z = torch.ops.pseudo.any(y)
- return z
-
- test_cases = [
- # match_output, match_placeholder, num_matches
- TestCase(False, False, 3),
- TestCase(True, False, 1),
- TestCase(False, True, 1),
- TestCase(True, True, 0)
- ]
-
-class PatternWithPseudoOneof:
- @staticmethod
- def forward(x):
- x = x.relu()
- x = torch.sigmoid(x)
-
- z = x.relu()
- z = torch.relu(z)
-
- y = x.relu()
- y = y + 1
-
- return y
-
- @staticmethod
- def pattern(a):
- y = a.relu()
- z = torch.ops.pseudo.oneof(y, targets=["torch.sigmoid", "operator.add"])
- return z
-
- test_cases = [
- # match_output, match_placeholder, num_matches
- TestCase(False, False, 2),
- TestCase(True, False, 1),
- TestCase(False, True, 1),
- TestCase(True, True, 0)
- ]
@instantiate_parametrized_tests
class TestFXMatcherUtils(JitTestCase):
@@ -670,9 +616,7 @@
MultipleOutputsMultipleOverlappingMatches,
MultipleOutputsMultipleNonOverlappingMatches,
MultipleOutputsIdenticalAnchor,
- MultipleOutputsHorizontalPattern,
- PatternWithPseudoAny,
- PatternWithPseudoOneof,
+ MultipleOutputsHorizontalPattern
])
def test_subgraph_matcher(self, test_model):
traced = symbolic_trace(test_model.forward)
diff --git a/torch/fx/passes/utils/matcher_utils.py b/torch/fx/passes/utils/matcher_utils.py
index 13d3433..31ae96a 100644
--- a/torch/fx/passes/utils/matcher_utils.py
+++ b/torch/fx/passes/utils/matcher_utils.py
@@ -1,7 +1,6 @@
from dataclasses import dataclass, field
from collections import defaultdict
import copy
-import torch.library
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx._compatibility import compatibility
@@ -10,42 +9,6 @@
__all__ = ['SubgraphMatcher', 'InternalMatch']
-pseudo = torch.library.Library("pseudo", "DEF")
-
-pseudo.define("any() -> ()")
-"""
-pseudo.any is a wildcard node that can be matched with any fx node with arbitrary number of inputs and outputs.
-For example, to match relu followed by one fx node:
- def pattern(a):
- y = a.relu()
- z = torch.ops.pseudo.any(y)
- return z
-"""
-
-pseudo.define("oneof(*, str[] targets) -> ()")
-"""
-pseudo.oneof is a special node that can be matched with a fx node whose target is in the permissible list.
-`targets` must be be a list of qualified name for operators, e.g. ["operator.add", "torch.sigmoid",
-"torch.ops.aten.foo", "torch.ops.prims.bar"]
-
-For example, using following pattern with pseudo.oneof
- def pattern(a):
- y = a.relu()
- z = torch.ops.pseudo.oneof(y, targets=["relu", "torch.sigmoid", "operator.add"])
- return z
-
-It will have 3 matches in the following function
- def forward(y):
- z = y.relu()
- x = z.relu() # first match
-
- x = x.relu()
- x = torch.sigmoid(x) # second match
-
- x = x.relu()
- return x + 1 # third match
-"""
-
@compatibility(is_backward_compatible=False)
@dataclass
class InternalMatch():
@@ -117,18 +80,6 @@
if not self.match_placeholder and pn.op == "placeholder":
return True
- if pn.target == torch.ops.pseudo.any:
- return True
-
- if pn.target == torch.ops.pseudo.oneof:
- permissible_targets: List[str] = pn.kwargs.get("targets", list()) # type: ignore[assignment]
- assert isinstance(permissible_targets, list), \
- "pseudo.oneof(permissible_targets=[\"foo\", \"bar\"]) only accept targets as a list"
- assert len(permissible_targets) > 0, "please specific as least one target for pseudo.oneof"
-
- if gn._pretty_print_target(gn.target) in permissible_targets:
- return True
-
if pn.op == gn.op:
if pn.op == "placeholder" or pn.op == "output":
return True