[dynamo] Refactor COMPARE_OP and comparison builtins (#122043)

This removes the duplicate handling of comparison ops between symbolic_convert and bultin and refactors the handling to use the binop infrastructure.  This change regresses overheads a bit, but this is fixed in the next PR.

New test skips are variants of `type(e) is np.ndarray` previously falling back to eager.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122043
Approved by: https://github.com/anijain2305
ghstack dependencies: #122039
diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv
index 09b5510..a5ff846 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv
@@ -86,7 +86,7 @@
 
 
 
-detectron2_fcos_r_50_fpn,pass,35
+detectron2_fcos_r_50_fpn,pass,94
 
 
 
diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv
index cca1164..1527a2c 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv
@@ -54,47 +54,47 @@
 
 
 
-detectron2_fasterrcnn_r_101_c4,pass,51
+detectron2_fasterrcnn_r_101_c4,pass,164
 
 
 
-detectron2_fasterrcnn_r_101_dc5,pass,51
+detectron2_fasterrcnn_r_101_dc5,pass,163
 
 
 
-detectron2_fasterrcnn_r_101_fpn,pass,55
+detectron2_fasterrcnn_r_101_fpn,pass,172
 
 
 
-detectron2_fasterrcnn_r_50_c4,pass,51
+detectron2_fasterrcnn_r_50_c4,pass,113
 
 
 
-detectron2_fasterrcnn_r_50_dc5,pass,51
+detectron2_fasterrcnn_r_50_dc5,pass,112
 
 
 
-detectron2_fasterrcnn_r_50_fpn,pass,55
+detectron2_fasterrcnn_r_50_fpn,pass,121
 
 
 
-detectron2_fcos_r_50_fpn,pass,38
+detectron2_fcos_r_50_fpn,pass,97
 
 
 
-detectron2_maskrcnn_r_101_c4,fail_accuracy,66
+detectron2_maskrcnn_r_101_c4,pass,182
 
 
 
-detectron2_maskrcnn_r_101_fpn,pass,73
+detectron2_maskrcnn_r_101_fpn,pass,192
 
 
 
-detectron2_maskrcnn_r_50_c4,pass,66
+detectron2_maskrcnn_r_50_c4,pass,131
 
 
 
-detectron2_maskrcnn_r_50_fpn,pass,73
+detectron2_maskrcnn_r_50_fpn,pass,141
 
 
 
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv
index ae1f61d..5b9f79f 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv
@@ -86,7 +86,7 @@
 
 
 
-detectron2_fcos_r_50_fpn,pass,35
+detectron2_fcos_r_50_fpn,pass,94
 
 
 
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv
index 4753fb4..c701671 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv
@@ -54,7 +54,7 @@
 
 
 
-detectron2_fcos_r_50_fpn,pass,38
+detectron2_fcos_r_50_fpn,pass,97
 
 
 
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv
index 02881ad..eacccb9 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv
@@ -86,7 +86,7 @@
 
 
 
-detectron2_fcos_r_50_fpn,pass,36
+detectron2_fcos_r_50_fpn,pass,95
 
 
 
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv
index 09b5510..a5ff846 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv
@@ -86,7 +86,7 @@
 
 
 
-detectron2_fcos_r_50_fpn,pass,35
+detectron2_fcos_r_50_fpn,pass,94
 
 
 
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv
index bd93461..dac50db 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv
@@ -86,7 +86,7 @@
 
 
 
-detectron2_fcos_r_50_fpn,pass,36
+detectron2_fcos_r_50_fpn,pass,95
 
 
 
diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index 5f3343b..4f3deb8 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -1460,6 +1460,26 @@
         return par_mul(x)
 
     @make_test
+    def test_list_add_then_mutate(x):
+        my_list = [1, x]
+        y = x / 4.0
+        my_list = my_list + [x / 2.0, 4]
+        my_list.append(y)
+        return sum(my_list)
+
+    @make_test
+    def test_list_expand_lhs(x):
+        return sum(4 * [x])
+
+    @make_test
+    def test_in_not_in(x):
+        mylist = [1, 2, 3, 4, 5, x]
+        myotherlist = [1, 2, 3, 4, 5]
+        assert 3 in mylist
+        assert 6 not in myotherlist
+        return sum(mylist)
+
+    @make_test
     def test_partials_udf_kwarg(x):
         par_mul = functools.partial(udf_mul, y=torch.ones(10, 10))
         return par_mul(x)
diff --git a/test/dynamo_expected_failures/TestFX.test_pytree_concrete b/test/dynamo_skips/TestHistogramdd.test_bins_array
similarity index 100%
copy from test/dynamo_expected_failures/TestFX.test_pytree_concrete
copy to test/dynamo_skips/TestHistogramdd.test_bins_array
diff --git a/test/dynamo_expected_failures/TestFX.test_pytree_concrete b/test/dynamo_skips/TestSqueeze.test_squeeze_type
similarity index 100%
rename from test/dynamo_expected_failures/TestFX.test_pytree_concrete
rename to test/dynamo_skips/TestSqueeze.test_squeeze_type
diff --git a/test/dynamo_expected_failures/TestFX.test_pytree_concrete b/test/dynamo_skips/TestSubscripting.test_test_zero_rank
similarity index 100%
copy from test/dynamo_expected_failures/TestFX.test_pytree_concrete
copy to test/dynamo_skips/TestSubscripting.test_test_zero_rank
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index f37dcbb..57e45a5 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -99,17 +99,11 @@
     UnknownVariable,
 )
 from .variables.nn_module import NNModuleVariable
-from .variables.tensor import (
-    supported_comparison_ops,
-    supported_const_comparison_ops,
-    SymNodeVariable,
-    TensorVariable,
-)
+from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
 from .variables.user_defined import (
     RemovableHandleVariable,
     UserDefinedClassVariable,
     UserDefinedObjectVariable,
-    UserDefinedVariable,
 )
 
 log = logging.getLogger(__name__)
@@ -117,6 +111,17 @@
 trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
 trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")
 tls = threading.local()
+compare_op_handlers: Dict[str, Any] = {
+    k: BuiltinVariable(v).call_function for k, v in supported_comparison_ops.items()
+}
+handle_contains = BuiltinVariable(operator.contains).call_function
+handle_not = BuiltinVariable(operator.not_).call_function
+compare_op_handlers["in"] = lambda tx, args, _: handle_contains(
+    tx, [*reversed(args)], {}
+)
+compare_op_handlers["not in"] = lambda tx, args, _: handle_not(
+    tx, [handle_contains(tx, [*reversed(args)], {})], {}
+)
 
 
 @dataclasses.dataclass
@@ -1200,49 +1205,7 @@
             unimplemented(f"FOR_ITER {typestr(it)}")
 
     def COMPARE_OP(self, inst):
-        left, right = self.popn(2)
-        op = inst.argval
-        if op == "in" or op == "not in":
-            self.push(right.call_method(self, "__contains__", [left], {}))
-            if op == "not in":
-                self.UNARY_NOT(inst)
-            return
-
-        if right.is_python_constant():
-            if left.is_python_constant():
-                # constant fold
-                return self.push(
-                    ConstantVariable(
-                        supported_comparison_ops[op](
-                            left.as_python_constant(), right.as_python_constant()
-                        ),
-                    )
-                )
-            elif (
-                op in supported_const_comparison_ops
-                and right.as_python_constant() is None
-                and isinstance(
-                    left,
-                    (
-                        TensorVariable,
-                        SymNodeVariable,
-                        NNModuleVariable,
-                        BaseListVariable,
-                        UserDefinedVariable,
-                        BaseUserFunctionVariable,
-                        ConstDictVariable,
-                    ),
-                )
-            ):
-                # <non-None> is None
-                return self.push(
-                    ConstantVariable(supported_const_comparison_ops[op](object(), None))
-                )
-        self.push(
-            BuiltinVariable(supported_comparison_ops[op]).call_function(
-                self, [left, right], {}
-            )
-        )
+        self.push(compare_op_handlers[inst.argval](self, self.popn(2), {}))
 
     def GET_ITER(self, inst):
         self.call_function(BuiltinVariable(iter), [self.pop()], {})
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 5fedf48..4a14d4b 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -38,7 +38,7 @@
     proxy_args_kwargs,
     tensortype_to_dtype,
 )
-from .base import MutableLocal, typestr, VariableTracker
+from .base import MutableLocal, VariableTracker
 from .constant import ConstantVariable
 from .ctx_manager import EventVariable, StreamVariable
 from .dicts import (
@@ -58,6 +58,7 @@
 )
 from .tensor import (
     FakeItemVariable,
+    supported_comparison_ops,
     SymNodeVariable,
     TensorVariable,
     UnspecializedPythonVariable,
@@ -167,6 +168,9 @@
             operator.ior,
             operator.index,
         }
+        from .tensor import supported_comparison_ops
+
+        fns.update(supported_comparison_ops.values())
         fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
         return fns
 
@@ -262,7 +266,17 @@
         # Multiple dispatch mechanism defining custom binop behavior for certain type
         # combinations. Handlers are attempted in order, and will be used if the type checks
         # match. They are expected to have the signature:
-        # fn(tx, arg0: VariableTracker, arg1: VariableTracker, options) -> VariableTracker
+        # fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker
+        from .dicts import DictKeys, SetVariable
+        from .functions import BaseUserFunctionVariable, UserFunctionVariable
+        from .nn_module import NNModuleVariable
+        from .tensor import supported_const_comparison_ops
+        from .torch import BaseTorchVariable
+        from .user_defined import (
+            UserDefinedClassVariable,
+            UserDefinedObjectVariable,
+            UserDefinedVariable,
+        )
 
         # Override table contains: op_fn -> [list of handlers]
         op_handlers = {}
@@ -280,7 +294,7 @@
                 tx,
                 a,
                 b,
-                options,
+                *,
                 forward_name=forward_name,
                 reverse_name=reverse_name,
             ):
@@ -310,9 +324,7 @@
                 ((VariableTracker, UserDefinedVariable), user_defined_handler)
             )
 
-            def user_defined_inplace_handler(
-                tx, a, b, options, forward_name=inplace_name
-            ):
+            def user_defined_inplace_handler(tx, a, b, *, forward_name=inplace_name):
                 return a.call_method(tx, forward_name, [b], {})
 
             op_handlers[in_place_op].append(
@@ -323,7 +335,7 @@
             )
 
             # Dynamic shape args
-            def dynamic_handler(tx, a, b, options, fn=op):
+            def dynamic_handler(tx, a, b, *, fn=op):
                 from .builder import wrap_fx_proxy
 
                 return wrap_fx_proxy(
@@ -331,7 +343,6 @@
                     tx.output.create_proxy(
                         "call_function", fn, *proxy_args_kwargs([a, b], {})
                     ),
-                    **options,
                 )
 
             op_handlers[op].append(
@@ -352,11 +363,11 @@
         # Special cases - lower precedence but still prefer these over constant folding
 
         # List-like addition (e.g. [1, 2] + [3, 4])
-        def tuple_add_handler(tx, a, b, options):
-            return TupleVariable(a.items + list(b.unpack_var_sequence(tx)), **options)
+        def tuple_add_handler(tx, a, b):
+            return TupleVariable([*a.items, *b.unpack_var_sequence(tx)])
 
-        def size_add_handler(tx, a, b, options):
-            return SizeVariable(a.items + list(b.unpack_var_sequence(tx)), **options)
+        def size_add_handler(tx, a, b):
+            return SizeVariable([*a.items, *b.unpack_var_sequence(tx)])
 
         list_like_addition_handlers = [
             # NB: Prefer the tuple-specific logic over base logic because of
@@ -376,18 +387,27 @@
             ),
             (
                 (ConstantVariable, TupleVariable),
-                lambda tx, a, b, options: TupleVariable(
-                    list(a.unpack_var_sequence(tx)) + b.items, **options
+                lambda tx, a, b: TupleVariable(
+                    [*a.unpack_var_sequence(tx), *b.items],
+                ),
+            ),
+            (
+                (
+                    ListVariable,
+                    (BaseListVariable, ConstantVariable, ListIteratorVariable),
+                ),
+                lambda tx, a, b: ListVariable(
+                    [*a.items, *b.unpack_var_sequence(tx)], mutable_local=MutableLocal()
                 ),
             ),
             (
                 (BaseListVariable, BaseListVariable),
-                lambda tx, a, b, options: type(a)(a.items + b.items, **options),
+                lambda tx, a, b: type(a)([*a.items, *b.items]),
             ),
         ]
         op_handlers[operator.add].extend(list_like_addition_handlers)
 
-        def list_iadd_handler(tx, a, b, _):
+        def list_iadd_handler(tx, a, b):
             if not a.mutable_local or not b.has_unpack_var_sequence(tx):
                 # Handler doesn't apply
                 return None
@@ -414,30 +434,169 @@
         op_handlers[operator.iadd].extend(list_like_iadd_handlers)
 
         # List-like expansion (e.g. [1, 2, 3] * 3)
-        def expand_list_like(tx, lst, const, options):
+        def expand_list_like(tx, lst, const):
+            if isinstance(lst, ConstantVariable):
+                lst, const = const, lst
             return lst.__class__(
                 items=lst.items * const.as_python_constant(),
                 mutable_local=MutableLocal(),
-                **options,
             )
 
         list_like_expansion_handlers = [
             ((ListVariable, ConstantVariable), expand_list_like),
             ((TupleVariable, ConstantVariable), expand_list_like),
-            (
-                (ConstantVariable, ListVariable),
-                lambda tx, a, b, options: expand_list_like(tx, b, a, options),
-            ),
-            (
-                (ConstantVariable, TupleVariable),
-                lambda tx, a, b, options: expand_list_like(tx, b, a, options),
-            ),
+            ((ConstantVariable, ListVariable), expand_list_like),
+            ((ConstantVariable, TupleVariable), expand_list_like),
         ]
         op_handlers[operator.mul].extend(list_like_expansion_handlers)
 
-        for key in op_handlers.keys():
-            # insert a mutable cache into each entry
-            op_handlers[key] = (op_handlers[key], dict())
+        size_or_tuple = (SizeVariable, TupleVariable)
+        has_set_items = (SetVariable, DictKeys)
+
+        def create_cmp_op_handlers(op):
+            def compare_by_value(tx, a, b):
+                return ConstantVariable(op(a.value, b.value))
+
+            result = [((ConstantVariable, ConstantVariable), compare_by_value)]
+
+            if op in supported_const_comparison_ops.values():
+                # Tensor is None, List is not None, etc
+                none_result = op(object(), None)
+                if op.__name__.startswith("is_"):
+
+                    def never(tx, a, b):
+                        return ConstantVariable(none_result)
+
+                    obj_op_none = never
+                    none_op_obj = never
+                else:
+
+                    def obj_op_none(tx, a, b: ConstantVariable):
+                        if b.value is None or b.value is True or b.value is False:
+                            return ConstantVariable(none_result)
+
+                    def none_op_obj(tx, a: ConstantVariable, b):
+                        if a.value is None or a.value is True or a.value is False:
+                            return ConstantVariable(none_result)
+
+                types_that_are_never_none = (
+                    TensorVariable,
+                    SymNodeVariable,
+                    NNModuleVariable,
+                    BaseListVariable,
+                    UserDefinedVariable,
+                    BaseUserFunctionVariable,
+                    ConstDictVariable,
+                    BaseTorchVariable,
+                )
+                result.extend(
+                    [
+                        (
+                            (types_that_are_never_none, ConstantVariable),
+                            obj_op_none,
+                        ),
+                        (
+                            (ConstantVariable, types_that_are_never_none),
+                            none_op_obj,
+                        ),
+                    ]
+                )
+
+            def list_compare_nocheck(tx, left, right):
+                return BaseListVariable.list_compare(tx, op, left, right)
+
+            def list_compare_check(tx, left, right):
+                if type(left) is not type(
+                    right
+                ):  # Mismatch in BaseListVariable subclasses
+                    unimplemented(f"{op.__name__}({left}, {right})")
+                return BaseListVariable.list_compare(tx, op, left, right)
+
+            def compare_set_items(tx, left, right):
+                return ConstantVariable(op(left.set_items, right.set_items))
+
+            op_var = BuiltinVariable(op)
+            result.extend(
+                [
+                    (
+                        (
+                            (UserFunctionVariable, BuiltinVariable),
+                            (UserFunctionVariable, BuiltinVariable),
+                        ),
+                        lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)),
+                    ),
+                    (
+                        (
+                            NNModuleVariable,
+                            NNModuleVariable,
+                        ),
+                        lambda tx, a, b: ConstantVariable(
+                            op(
+                                tx.output.get_submodule(a.module_key),
+                                tx.output.get_submodule(b.module_key),
+                            )
+                        ),
+                    ),
+                    ((size_or_tuple, size_or_tuple), list_compare_nocheck),
+                    (
+                        (variables.BaseListVariable, variables.BaseListVariable),
+                        list_compare_check,
+                    ),
+                    ((has_set_items, has_set_items), compare_set_items),
+                    # TODO(jansel): UserDefinedObjectVariable is wrong and could invoke user code
+                    (
+                        (UserDefinedObjectVariable, UserDefinedObjectVariable),
+                        compare_by_value,
+                    ),
+                    (
+                        (UserDefinedClassVariable, UserDefinedClassVariable),
+                        compare_by_value,
+                    ),
+                    (
+                        (
+                            (StreamVariable, EventVariable, ConstantVariable),
+                            (StreamVariable, EventVariable, ConstantVariable),
+                        ),
+                        compare_by_value,
+                    ),
+                    (
+                        (TensorVariable, VariableTracker),
+                        op_var._comparison_with_tensor,
+                    ),
+                    (
+                        (VariableTracker, TensorVariable),
+                        op_var._comparison_with_tensor,
+                    ),
+                    (
+                        (SymNodeVariable, VariableTracker),
+                        op_var._comparison_with_symnode,
+                    ),
+                    (
+                        (VariableTracker, SymNodeVariable),
+                        op_var._comparison_with_symnode,
+                    ),
+                ]
+            )
+
+            if op.__name__.startswith("is_"):
+
+                def handle_is(tx, left, right):
+                    # If the two objects are of different type, we can safely return False
+                    # and True for `is` and `is not`, respectively
+                    if type(left) is not type(right):
+                        return ConstantVariable.create(op.__name__ != "is_")
+
+                result.append(((VariableTracker, VariableTracker), handle_is))
+
+            return result
+
+        for op in supported_comparison_ops.values():
+            assert callable(op)
+            assert op not in op_handlers
+            op_handlers[op] = create_cmp_op_handlers(op)
+
+        for op in op_handlers.keys():
+            op_handlers[op] = (op_handlers[op], dict())
         return op_handlers
 
     @staticmethod
@@ -452,13 +611,26 @@
         if hit is not False:
             return hit
 
-        # Return first handler that matches the type checks
+        a_type = type(a)
+        b_type = type(b)
+        matches = []
         for (type1, type2), handler in handlers:
-            if isinstance(a, type1) and isinstance(b, type2):
-                cache[cache_key] = handler
-                return handler
-        cache[cache_key] = None
-        return None
+            if issubclass(a_type, type1) and issubclass(b_type, type2):
+                matches.append(handler)
+
+        if not matches:
+            result = None
+        elif len(matches) == 1:
+            result = matches[0]
+        else:
+
+            def result(*args):
+                for fn in matches:
+                    rv = fn(*args)
+                    if rv:
+                        return rv
+
+        return result
 
     def can_insert_in_graph(self):
         return self.fn in self._fx_graph_functions()
@@ -666,7 +838,7 @@
             # Try to find a handler for the arg types; otherwise, fall through to constant handler
             binop_handler = BuiltinVariable._find_binop_handler(fn, args[0], args[1])
             if binop_handler:
-                res = binop_handler(tx, args[0], args[1], {})
+                res = binop_handler(tx, args[0], args[1])
                 if res is not None:
                     return res
 
@@ -1511,150 +1683,63 @@
     def call_deepcopy(self, tx, x):
         unimplemented(f"copy.deepcopy {repr(x)}")
 
-    def _comparison(self, tx, left, right):
-        """
-        Used to implement comparison operators for different types.
-        For example, list1 < list2 is implemented differently from tensor1 < tensor2
-        """
-        from . import (
-            BaseListVariable,
-            ConstantVariable,
-            NNModuleVariable,
-            TensorVariable,
-            UserDefinedObjectVariable,
-            UserFunctionVariable,
-        )
-        from .lists import SizeVariable
-        from .tensor import (
-            supported_const_comparison_op_values,
-            supported_tensor_comparison_op_values,
-        )
+    def _comparison_with_tensor(self, tx, left, right):
+        from .builder import wrap_fx_proxy_cls
+        from .tensor import supported_tensor_comparison_op_values
 
         op = self.fn
 
-        def _unimplemented():
-            unimplemented(f"comparison {typestr(left)} {op} {typestr(right)}")
-
-        if (
-            all(
-                isinstance(x, (NNModuleVariable, ConstantVariable))
-                for x in [left, right]
-            )
-            and op in supported_const_comparison_op_values
-        ):
-            left = (
-                tx.output.get_submodule(left.module_key)
-                if isinstance(left, NNModuleVariable)
-                else left.as_python_constant()
-            )
-            right = (
-                tx.output.get_submodule(right.module_key)
-                if isinstance(right, NNModuleVariable)
-                else right.as_python_constant()
-            )
-            return ConstantVariable.create(op(left, right))
-
-        if isinstance(left, UserFunctionVariable):
-            if op not in supported_const_comparison_op_values:
-                _unimplemented()
-            if not isinstance(right, UserFunctionVariable):
-                _unimplemented()
-            return ConstantVariable.create(op(left.fn, right.fn))
-
-        # Note, we have a rare BaseListVariable subtype mismatch with valid comparison
-        # x = torch.randn([3, 3])
-        # x.size() == (3, 3) # True
-        # (3, 3) == x.size() # True
-        if isinstance(left, (SizeVariable, TupleVariable)) and isinstance(
-            right, (TupleVariable, SizeVariable)
-        ):
-            return BaseListVariable.list_compare(tx, op, left, right)
-
-        if isinstance(left, BaseListVariable):
-            if not type(left) == type(right):  # Mismatch in BaseListVariable subclasses
-                _unimplemented()
-            return BaseListVariable.list_compare(tx, op, left, right)
-
-        # If they implement set semantics (e.g. SetVariable or DictKeys)
-        if hasattr(left, "set_items") and hasattr(right, "set_items"):
-            return ConstantVariable.create(op(left.set_items, right.set_items))
-
-        if isinstance(left, TensorVariable) or isinstance(right, TensorVariable):
-            from .builder import wrap_fx_proxy_cls
-
-            if op in [operator.is_, operator.is_not]:
-                is_result = (
-                    isinstance(left, TensorVariable)
-                    and isinstance(right, TensorVariable)
-                    and id(extract_fake_example_value(left.as_proxy().node))
-                    == id(extract_fake_example_value(right.as_proxy().node))
-                )
-                if op is operator.is_:
-                    return ConstantVariable.create(is_result)
-                else:
-                    return ConstantVariable.create(not is_result)
-
-            if op not in supported_tensor_comparison_op_values:
-                _unimplemented()
-            if (
+        if op in [operator.is_, operator.is_not]:
+            is_result = (
                 isinstance(left, TensorVariable)
                 and isinstance(right, TensorVariable)
-                and (left.size and right.size) is not None
-                and left.size != right.size
-            ):
-                try:
-                    torch.broadcast_shapes(left.size, right.size)
-                except RuntimeError:
-                    # not broadcastable, can't be compared
-                    _unimplemented()
-            tensor_cls = left if isinstance(left, TensorVariable) else right
-            proxy = tx.output.create_proxy(
-                "call_function", op, (left.as_proxy(), right.as_proxy()), {}
+                and id(extract_fake_example_value(left.as_proxy().node))
+                == id(extract_fake_example_value(right.as_proxy().node))
             )
-            return wrap_fx_proxy_cls(
-                type(tensor_cls),  # handle Ndarrays and Tensors
-                tx,
-                proxy,
-            )
+            if op is operator.is_:
+                return ConstantVariable.create(is_result)
+            else:
+                return ConstantVariable.create(not is_result)
 
-        if isinstance(left, SymNodeVariable) or isinstance(right, SymNodeVariable):
-            if op not in supported_tensor_comparison_op_values:
-                _unimplemented()
-
-            proxy = tx.output.create_proxy(
-                "call_function", op, (left.as_proxy(), right.as_proxy()), {}
-            )
-            return SymNodeVariable.create(
-                tx,
-                proxy,
-                sym_num=None,
-            )
-
-        if isinstance(left, UserDefinedObjectVariable) and isinstance(
-            right, UserDefinedObjectVariable
+        if op not in supported_tensor_comparison_op_values:
+            unimplemented(f"{op.__name__}({left}, {right})")
+        if (
+            isinstance(left, TensorVariable)
+            and isinstance(right, TensorVariable)
+            and (left.size and right.size) is not None
+            and left.size != right.size
         ):
-            return ConstantVariable.create(op(left.value, right.value))
+            try:
+                torch.broadcast_shapes(left.size, right.size)
+            except RuntimeError:
+                # not broadcastable, can't be compared
+                unimplemented(f"{op.__name__}({left}, {right})")
+        tensor_cls = left if isinstance(left, TensorVariable) else right
+        proxy = tx.output.create_proxy(
+            "call_function", op, (left.as_proxy(), right.as_proxy()), {}
+        )
+        return wrap_fx_proxy_cls(
+            type(tensor_cls),  # handle Ndarrays and Tensors
+            tx,
+            proxy,
+        )
 
-        if isinstance(left, (StreamVariable, EventVariable)) or isinstance(
-            right, (StreamVariable, EventVariable)
-        ):
-            if type(left) == type(right) and op is operator.eq:
-                return ConstantVariable(op(left.value, right.value))
+    def _comparison_with_symnode(self, tx, left, right):
+        from .tensor import supported_tensor_comparison_op_values
 
-            if isinstance(right, ConstantVariable) or isinstance(
-                left, ConstantVariable
-            ):
-                return ConstantVariable(op(left.value, right.value))
+        op = self.fn
 
-        if op.__name__.startswith("is_"):
-            # If the two objects are of different type, we can safely return False and True for `is` and `is not`, respectively
-            if type(left) is not type(right):
-                return ConstantVariable.create(op.__name__ != "is_")
+        if op not in supported_tensor_comparison_op_values:
+            unimplemented(f"{op.__name__}({left}, {right})")
 
-        if isinstance(left, BuiltinVariable) and isinstance(right, BuiltinVariable):
-            return ConstantVariable.create(op(left.fn, right.fn))
-
-        _unimplemented()
+        proxy = tx.output.create_proxy(
+            "call_function", op, (left.as_proxy(), right.as_proxy()), {}
+        )
+        return SymNodeVariable.create(
+            tx,
+            proxy,
+            sym_num=None,
+        )
 
     def call_and_(self, tx, a, b):
         # Rely on constant_handler
@@ -1711,14 +1796,8 @@
 
         return None
 
-    call_eq = _comparison
-    call_gt = _comparison
-    call_lt = _comparison
-    call_ge = _comparison
-    call_le = _comparison
-    call_ne = _comparison
-    call_is_ = _comparison
-    call_is_not = _comparison
+    def call_contains(self, tx, a: VariableTracker, b: VariableTracker):
+        return a.call_method(tx, "__contains__", [b], {})
 
     call_all = _polyfill_call_impl("all")
     call_any = _polyfill_call_impl("any")