[dynamo] Refactor COMPARE_OP and comparison builtins (#122043)
This removes the duplicate handling of comparison ops between symbolic_convert and bultin and refactors the handling to use the binop infrastructure. This change regresses overheads a bit, but this is fixed in the next PR.
New test skips are variants of `type(e) is np.ndarray` previously falling back to eager.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122043
Approved by: https://github.com/anijain2305
ghstack dependencies: #122039
diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv
index 09b5510..a5ff846 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv
@@ -86,7 +86,7 @@
-detectron2_fcos_r_50_fpn,pass,35
+detectron2_fcos_r_50_fpn,pass,94
diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv
index cca1164..1527a2c 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv
@@ -54,47 +54,47 @@
-detectron2_fasterrcnn_r_101_c4,pass,51
+detectron2_fasterrcnn_r_101_c4,pass,164
-detectron2_fasterrcnn_r_101_dc5,pass,51
+detectron2_fasterrcnn_r_101_dc5,pass,163
-detectron2_fasterrcnn_r_101_fpn,pass,55
+detectron2_fasterrcnn_r_101_fpn,pass,172
-detectron2_fasterrcnn_r_50_c4,pass,51
+detectron2_fasterrcnn_r_50_c4,pass,113
-detectron2_fasterrcnn_r_50_dc5,pass,51
+detectron2_fasterrcnn_r_50_dc5,pass,112
-detectron2_fasterrcnn_r_50_fpn,pass,55
+detectron2_fasterrcnn_r_50_fpn,pass,121
-detectron2_fcos_r_50_fpn,pass,38
+detectron2_fcos_r_50_fpn,pass,97
-detectron2_maskrcnn_r_101_c4,fail_accuracy,66
+detectron2_maskrcnn_r_101_c4,pass,182
-detectron2_maskrcnn_r_101_fpn,pass,73
+detectron2_maskrcnn_r_101_fpn,pass,192
-detectron2_maskrcnn_r_50_c4,pass,66
+detectron2_maskrcnn_r_50_c4,pass,131
-detectron2_maskrcnn_r_50_fpn,pass,73
+detectron2_maskrcnn_r_50_fpn,pass,141
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv
index ae1f61d..5b9f79f 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv
@@ -86,7 +86,7 @@
-detectron2_fcos_r_50_fpn,pass,35
+detectron2_fcos_r_50_fpn,pass,94
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv
index 4753fb4..c701671 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv
@@ -54,7 +54,7 @@
-detectron2_fcos_r_50_fpn,pass,38
+detectron2_fcos_r_50_fpn,pass,97
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv
index 02881ad..eacccb9 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv
@@ -86,7 +86,7 @@
-detectron2_fcos_r_50_fpn,pass,36
+detectron2_fcos_r_50_fpn,pass,95
diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv
index 09b5510..a5ff846 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv
@@ -86,7 +86,7 @@
-detectron2_fcos_r_50_fpn,pass,35
+detectron2_fcos_r_50_fpn,pass,94
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv
index bd93461..dac50db 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv
@@ -86,7 +86,7 @@
-detectron2_fcos_r_50_fpn,pass,36
+detectron2_fcos_r_50_fpn,pass,95
diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index 5f3343b..4f3deb8 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -1460,6 +1460,26 @@
return par_mul(x)
@make_test
+ def test_list_add_then_mutate(x):
+ my_list = [1, x]
+ y = x / 4.0
+ my_list = my_list + [x / 2.0, 4]
+ my_list.append(y)
+ return sum(my_list)
+
+ @make_test
+ def test_list_expand_lhs(x):
+ return sum(4 * [x])
+
+ @make_test
+ def test_in_not_in(x):
+ mylist = [1, 2, 3, 4, 5, x]
+ myotherlist = [1, 2, 3, 4, 5]
+ assert 3 in mylist
+ assert 6 not in myotherlist
+ return sum(mylist)
+
+ @make_test
def test_partials_udf_kwarg(x):
par_mul = functools.partial(udf_mul, y=torch.ones(10, 10))
return par_mul(x)
diff --git a/test/dynamo_expected_failures/TestFX.test_pytree_concrete b/test/dynamo_skips/TestHistogramdd.test_bins_array
similarity index 100%
copy from test/dynamo_expected_failures/TestFX.test_pytree_concrete
copy to test/dynamo_skips/TestHistogramdd.test_bins_array
diff --git a/test/dynamo_expected_failures/TestFX.test_pytree_concrete b/test/dynamo_skips/TestSqueeze.test_squeeze_type
similarity index 100%
rename from test/dynamo_expected_failures/TestFX.test_pytree_concrete
rename to test/dynamo_skips/TestSqueeze.test_squeeze_type
diff --git a/test/dynamo_expected_failures/TestFX.test_pytree_concrete b/test/dynamo_skips/TestSubscripting.test_test_zero_rank
similarity index 100%
copy from test/dynamo_expected_failures/TestFX.test_pytree_concrete
copy to test/dynamo_skips/TestSubscripting.test_test_zero_rank
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index f37dcbb..57e45a5 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -99,17 +99,11 @@
UnknownVariable,
)
from .variables.nn_module import NNModuleVariable
-from .variables.tensor import (
- supported_comparison_ops,
- supported_const_comparison_ops,
- SymNodeVariable,
- TensorVariable,
-)
+from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
from .variables.user_defined import (
RemovableHandleVariable,
UserDefinedClassVariable,
UserDefinedObjectVariable,
- UserDefinedVariable,
)
log = logging.getLogger(__name__)
@@ -117,6 +111,17 @@
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")
tls = threading.local()
+compare_op_handlers: Dict[str, Any] = {
+ k: BuiltinVariable(v).call_function for k, v in supported_comparison_ops.items()
+}
+handle_contains = BuiltinVariable(operator.contains).call_function
+handle_not = BuiltinVariable(operator.not_).call_function
+compare_op_handlers["in"] = lambda tx, args, _: handle_contains(
+ tx, [*reversed(args)], {}
+)
+compare_op_handlers["not in"] = lambda tx, args, _: handle_not(
+ tx, [handle_contains(tx, [*reversed(args)], {})], {}
+)
@dataclasses.dataclass
@@ -1200,49 +1205,7 @@
unimplemented(f"FOR_ITER {typestr(it)}")
def COMPARE_OP(self, inst):
- left, right = self.popn(2)
- op = inst.argval
- if op == "in" or op == "not in":
- self.push(right.call_method(self, "__contains__", [left], {}))
- if op == "not in":
- self.UNARY_NOT(inst)
- return
-
- if right.is_python_constant():
- if left.is_python_constant():
- # constant fold
- return self.push(
- ConstantVariable(
- supported_comparison_ops[op](
- left.as_python_constant(), right.as_python_constant()
- ),
- )
- )
- elif (
- op in supported_const_comparison_ops
- and right.as_python_constant() is None
- and isinstance(
- left,
- (
- TensorVariable,
- SymNodeVariable,
- NNModuleVariable,
- BaseListVariable,
- UserDefinedVariable,
- BaseUserFunctionVariable,
- ConstDictVariable,
- ),
- )
- ):
- # <non-None> is None
- return self.push(
- ConstantVariable(supported_const_comparison_ops[op](object(), None))
- )
- self.push(
- BuiltinVariable(supported_comparison_ops[op]).call_function(
- self, [left, right], {}
- )
- )
+ self.push(compare_op_handlers[inst.argval](self, self.popn(2), {}))
def GET_ITER(self, inst):
self.call_function(BuiltinVariable(iter), [self.pop()], {})
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 5fedf48..4a14d4b 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -38,7 +38,7 @@
proxy_args_kwargs,
tensortype_to_dtype,
)
-from .base import MutableLocal, typestr, VariableTracker
+from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
from .ctx_manager import EventVariable, StreamVariable
from .dicts import (
@@ -58,6 +58,7 @@
)
from .tensor import (
FakeItemVariable,
+ supported_comparison_ops,
SymNodeVariable,
TensorVariable,
UnspecializedPythonVariable,
@@ -167,6 +168,9 @@
operator.ior,
operator.index,
}
+ from .tensor import supported_comparison_ops
+
+ fns.update(supported_comparison_ops.values())
fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
return fns
@@ -262,7 +266,17 @@
# Multiple dispatch mechanism defining custom binop behavior for certain type
# combinations. Handlers are attempted in order, and will be used if the type checks
# match. They are expected to have the signature:
- # fn(tx, arg0: VariableTracker, arg1: VariableTracker, options) -> VariableTracker
+ # fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker
+ from .dicts import DictKeys, SetVariable
+ from .functions import BaseUserFunctionVariable, UserFunctionVariable
+ from .nn_module import NNModuleVariable
+ from .tensor import supported_const_comparison_ops
+ from .torch import BaseTorchVariable
+ from .user_defined import (
+ UserDefinedClassVariable,
+ UserDefinedObjectVariable,
+ UserDefinedVariable,
+ )
# Override table contains: op_fn -> [list of handlers]
op_handlers = {}
@@ -280,7 +294,7 @@
tx,
a,
b,
- options,
+ *,
forward_name=forward_name,
reverse_name=reverse_name,
):
@@ -310,9 +324,7 @@
((VariableTracker, UserDefinedVariable), user_defined_handler)
)
- def user_defined_inplace_handler(
- tx, a, b, options, forward_name=inplace_name
- ):
+ def user_defined_inplace_handler(tx, a, b, *, forward_name=inplace_name):
return a.call_method(tx, forward_name, [b], {})
op_handlers[in_place_op].append(
@@ -323,7 +335,7 @@
)
# Dynamic shape args
- def dynamic_handler(tx, a, b, options, fn=op):
+ def dynamic_handler(tx, a, b, *, fn=op):
from .builder import wrap_fx_proxy
return wrap_fx_proxy(
@@ -331,7 +343,6 @@
tx.output.create_proxy(
"call_function", fn, *proxy_args_kwargs([a, b], {})
),
- **options,
)
op_handlers[op].append(
@@ -352,11 +363,11 @@
# Special cases - lower precedence but still prefer these over constant folding
# List-like addition (e.g. [1, 2] + [3, 4])
- def tuple_add_handler(tx, a, b, options):
- return TupleVariable(a.items + list(b.unpack_var_sequence(tx)), **options)
+ def tuple_add_handler(tx, a, b):
+ return TupleVariable([*a.items, *b.unpack_var_sequence(tx)])
- def size_add_handler(tx, a, b, options):
- return SizeVariable(a.items + list(b.unpack_var_sequence(tx)), **options)
+ def size_add_handler(tx, a, b):
+ return SizeVariable([*a.items, *b.unpack_var_sequence(tx)])
list_like_addition_handlers = [
# NB: Prefer the tuple-specific logic over base logic because of
@@ -376,18 +387,27 @@
),
(
(ConstantVariable, TupleVariable),
- lambda tx, a, b, options: TupleVariable(
- list(a.unpack_var_sequence(tx)) + b.items, **options
+ lambda tx, a, b: TupleVariable(
+ [*a.unpack_var_sequence(tx), *b.items],
+ ),
+ ),
+ (
+ (
+ ListVariable,
+ (BaseListVariable, ConstantVariable, ListIteratorVariable),
+ ),
+ lambda tx, a, b: ListVariable(
+ [*a.items, *b.unpack_var_sequence(tx)], mutable_local=MutableLocal()
),
),
(
(BaseListVariable, BaseListVariable),
- lambda tx, a, b, options: type(a)(a.items + b.items, **options),
+ lambda tx, a, b: type(a)([*a.items, *b.items]),
),
]
op_handlers[operator.add].extend(list_like_addition_handlers)
- def list_iadd_handler(tx, a, b, _):
+ def list_iadd_handler(tx, a, b):
if not a.mutable_local or not b.has_unpack_var_sequence(tx):
# Handler doesn't apply
return None
@@ -414,30 +434,169 @@
op_handlers[operator.iadd].extend(list_like_iadd_handlers)
# List-like expansion (e.g. [1, 2, 3] * 3)
- def expand_list_like(tx, lst, const, options):
+ def expand_list_like(tx, lst, const):
+ if isinstance(lst, ConstantVariable):
+ lst, const = const, lst
return lst.__class__(
items=lst.items * const.as_python_constant(),
mutable_local=MutableLocal(),
- **options,
)
list_like_expansion_handlers = [
((ListVariable, ConstantVariable), expand_list_like),
((TupleVariable, ConstantVariable), expand_list_like),
- (
- (ConstantVariable, ListVariable),
- lambda tx, a, b, options: expand_list_like(tx, b, a, options),
- ),
- (
- (ConstantVariable, TupleVariable),
- lambda tx, a, b, options: expand_list_like(tx, b, a, options),
- ),
+ ((ConstantVariable, ListVariable), expand_list_like),
+ ((ConstantVariable, TupleVariable), expand_list_like),
]
op_handlers[operator.mul].extend(list_like_expansion_handlers)
- for key in op_handlers.keys():
- # insert a mutable cache into each entry
- op_handlers[key] = (op_handlers[key], dict())
+ size_or_tuple = (SizeVariable, TupleVariable)
+ has_set_items = (SetVariable, DictKeys)
+
+ def create_cmp_op_handlers(op):
+ def compare_by_value(tx, a, b):
+ return ConstantVariable(op(a.value, b.value))
+
+ result = [((ConstantVariable, ConstantVariable), compare_by_value)]
+
+ if op in supported_const_comparison_ops.values():
+ # Tensor is None, List is not None, etc
+ none_result = op(object(), None)
+ if op.__name__.startswith("is_"):
+
+ def never(tx, a, b):
+ return ConstantVariable(none_result)
+
+ obj_op_none = never
+ none_op_obj = never
+ else:
+
+ def obj_op_none(tx, a, b: ConstantVariable):
+ if b.value is None or b.value is True or b.value is False:
+ return ConstantVariable(none_result)
+
+ def none_op_obj(tx, a: ConstantVariable, b):
+ if a.value is None or a.value is True or a.value is False:
+ return ConstantVariable(none_result)
+
+ types_that_are_never_none = (
+ TensorVariable,
+ SymNodeVariable,
+ NNModuleVariable,
+ BaseListVariable,
+ UserDefinedVariable,
+ BaseUserFunctionVariable,
+ ConstDictVariable,
+ BaseTorchVariable,
+ )
+ result.extend(
+ [
+ (
+ (types_that_are_never_none, ConstantVariable),
+ obj_op_none,
+ ),
+ (
+ (ConstantVariable, types_that_are_never_none),
+ none_op_obj,
+ ),
+ ]
+ )
+
+ def list_compare_nocheck(tx, left, right):
+ return BaseListVariable.list_compare(tx, op, left, right)
+
+ def list_compare_check(tx, left, right):
+ if type(left) is not type(
+ right
+ ): # Mismatch in BaseListVariable subclasses
+ unimplemented(f"{op.__name__}({left}, {right})")
+ return BaseListVariable.list_compare(tx, op, left, right)
+
+ def compare_set_items(tx, left, right):
+ return ConstantVariable(op(left.set_items, right.set_items))
+
+ op_var = BuiltinVariable(op)
+ result.extend(
+ [
+ (
+ (
+ (UserFunctionVariable, BuiltinVariable),
+ (UserFunctionVariable, BuiltinVariable),
+ ),
+ lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)),
+ ),
+ (
+ (
+ NNModuleVariable,
+ NNModuleVariable,
+ ),
+ lambda tx, a, b: ConstantVariable(
+ op(
+ tx.output.get_submodule(a.module_key),
+ tx.output.get_submodule(b.module_key),
+ )
+ ),
+ ),
+ ((size_or_tuple, size_or_tuple), list_compare_nocheck),
+ (
+ (variables.BaseListVariable, variables.BaseListVariable),
+ list_compare_check,
+ ),
+ ((has_set_items, has_set_items), compare_set_items),
+ # TODO(jansel): UserDefinedObjectVariable is wrong and could invoke user code
+ (
+ (UserDefinedObjectVariable, UserDefinedObjectVariable),
+ compare_by_value,
+ ),
+ (
+ (UserDefinedClassVariable, UserDefinedClassVariable),
+ compare_by_value,
+ ),
+ (
+ (
+ (StreamVariable, EventVariable, ConstantVariable),
+ (StreamVariable, EventVariable, ConstantVariable),
+ ),
+ compare_by_value,
+ ),
+ (
+ (TensorVariable, VariableTracker),
+ op_var._comparison_with_tensor,
+ ),
+ (
+ (VariableTracker, TensorVariable),
+ op_var._comparison_with_tensor,
+ ),
+ (
+ (SymNodeVariable, VariableTracker),
+ op_var._comparison_with_symnode,
+ ),
+ (
+ (VariableTracker, SymNodeVariable),
+ op_var._comparison_with_symnode,
+ ),
+ ]
+ )
+
+ if op.__name__.startswith("is_"):
+
+ def handle_is(tx, left, right):
+ # If the two objects are of different type, we can safely return False
+ # and True for `is` and `is not`, respectively
+ if type(left) is not type(right):
+ return ConstantVariable.create(op.__name__ != "is_")
+
+ result.append(((VariableTracker, VariableTracker), handle_is))
+
+ return result
+
+ for op in supported_comparison_ops.values():
+ assert callable(op)
+ assert op not in op_handlers
+ op_handlers[op] = create_cmp_op_handlers(op)
+
+ for op in op_handlers.keys():
+ op_handlers[op] = (op_handlers[op], dict())
return op_handlers
@staticmethod
@@ -452,13 +611,26 @@
if hit is not False:
return hit
- # Return first handler that matches the type checks
+ a_type = type(a)
+ b_type = type(b)
+ matches = []
for (type1, type2), handler in handlers:
- if isinstance(a, type1) and isinstance(b, type2):
- cache[cache_key] = handler
- return handler
- cache[cache_key] = None
- return None
+ if issubclass(a_type, type1) and issubclass(b_type, type2):
+ matches.append(handler)
+
+ if not matches:
+ result = None
+ elif len(matches) == 1:
+ result = matches[0]
+ else:
+
+ def result(*args):
+ for fn in matches:
+ rv = fn(*args)
+ if rv:
+ return rv
+
+ return result
def can_insert_in_graph(self):
return self.fn in self._fx_graph_functions()
@@ -666,7 +838,7 @@
# Try to find a handler for the arg types; otherwise, fall through to constant handler
binop_handler = BuiltinVariable._find_binop_handler(fn, args[0], args[1])
if binop_handler:
- res = binop_handler(tx, args[0], args[1], {})
+ res = binop_handler(tx, args[0], args[1])
if res is not None:
return res
@@ -1511,150 +1683,63 @@
def call_deepcopy(self, tx, x):
unimplemented(f"copy.deepcopy {repr(x)}")
- def _comparison(self, tx, left, right):
- """
- Used to implement comparison operators for different types.
- For example, list1 < list2 is implemented differently from tensor1 < tensor2
- """
- from . import (
- BaseListVariable,
- ConstantVariable,
- NNModuleVariable,
- TensorVariable,
- UserDefinedObjectVariable,
- UserFunctionVariable,
- )
- from .lists import SizeVariable
- from .tensor import (
- supported_const_comparison_op_values,
- supported_tensor_comparison_op_values,
- )
+ def _comparison_with_tensor(self, tx, left, right):
+ from .builder import wrap_fx_proxy_cls
+ from .tensor import supported_tensor_comparison_op_values
op = self.fn
- def _unimplemented():
- unimplemented(f"comparison {typestr(left)} {op} {typestr(right)}")
-
- if (
- all(
- isinstance(x, (NNModuleVariable, ConstantVariable))
- for x in [left, right]
- )
- and op in supported_const_comparison_op_values
- ):
- left = (
- tx.output.get_submodule(left.module_key)
- if isinstance(left, NNModuleVariable)
- else left.as_python_constant()
- )
- right = (
- tx.output.get_submodule(right.module_key)
- if isinstance(right, NNModuleVariable)
- else right.as_python_constant()
- )
- return ConstantVariable.create(op(left, right))
-
- if isinstance(left, UserFunctionVariable):
- if op not in supported_const_comparison_op_values:
- _unimplemented()
- if not isinstance(right, UserFunctionVariable):
- _unimplemented()
- return ConstantVariable.create(op(left.fn, right.fn))
-
- # Note, we have a rare BaseListVariable subtype mismatch with valid comparison
- # x = torch.randn([3, 3])
- # x.size() == (3, 3) # True
- # (3, 3) == x.size() # True
- if isinstance(left, (SizeVariable, TupleVariable)) and isinstance(
- right, (TupleVariable, SizeVariable)
- ):
- return BaseListVariable.list_compare(tx, op, left, right)
-
- if isinstance(left, BaseListVariable):
- if not type(left) == type(right): # Mismatch in BaseListVariable subclasses
- _unimplemented()
- return BaseListVariable.list_compare(tx, op, left, right)
-
- # If they implement set semantics (e.g. SetVariable or DictKeys)
- if hasattr(left, "set_items") and hasattr(right, "set_items"):
- return ConstantVariable.create(op(left.set_items, right.set_items))
-
- if isinstance(left, TensorVariable) or isinstance(right, TensorVariable):
- from .builder import wrap_fx_proxy_cls
-
- if op in [operator.is_, operator.is_not]:
- is_result = (
- isinstance(left, TensorVariable)
- and isinstance(right, TensorVariable)
- and id(extract_fake_example_value(left.as_proxy().node))
- == id(extract_fake_example_value(right.as_proxy().node))
- )
- if op is operator.is_:
- return ConstantVariable.create(is_result)
- else:
- return ConstantVariable.create(not is_result)
-
- if op not in supported_tensor_comparison_op_values:
- _unimplemented()
- if (
+ if op in [operator.is_, operator.is_not]:
+ is_result = (
isinstance(left, TensorVariable)
and isinstance(right, TensorVariable)
- and (left.size and right.size) is not None
- and left.size != right.size
- ):
- try:
- torch.broadcast_shapes(left.size, right.size)
- except RuntimeError:
- # not broadcastable, can't be compared
- _unimplemented()
- tensor_cls = left if isinstance(left, TensorVariable) else right
- proxy = tx.output.create_proxy(
- "call_function", op, (left.as_proxy(), right.as_proxy()), {}
+ and id(extract_fake_example_value(left.as_proxy().node))
+ == id(extract_fake_example_value(right.as_proxy().node))
)
- return wrap_fx_proxy_cls(
- type(tensor_cls), # handle Ndarrays and Tensors
- tx,
- proxy,
- )
+ if op is operator.is_:
+ return ConstantVariable.create(is_result)
+ else:
+ return ConstantVariable.create(not is_result)
- if isinstance(left, SymNodeVariable) or isinstance(right, SymNodeVariable):
- if op not in supported_tensor_comparison_op_values:
- _unimplemented()
-
- proxy = tx.output.create_proxy(
- "call_function", op, (left.as_proxy(), right.as_proxy()), {}
- )
- return SymNodeVariable.create(
- tx,
- proxy,
- sym_num=None,
- )
-
- if isinstance(left, UserDefinedObjectVariable) and isinstance(
- right, UserDefinedObjectVariable
+ if op not in supported_tensor_comparison_op_values:
+ unimplemented(f"{op.__name__}({left}, {right})")
+ if (
+ isinstance(left, TensorVariable)
+ and isinstance(right, TensorVariable)
+ and (left.size and right.size) is not None
+ and left.size != right.size
):
- return ConstantVariable.create(op(left.value, right.value))
+ try:
+ torch.broadcast_shapes(left.size, right.size)
+ except RuntimeError:
+ # not broadcastable, can't be compared
+ unimplemented(f"{op.__name__}({left}, {right})")
+ tensor_cls = left if isinstance(left, TensorVariable) else right
+ proxy = tx.output.create_proxy(
+ "call_function", op, (left.as_proxy(), right.as_proxy()), {}
+ )
+ return wrap_fx_proxy_cls(
+ type(tensor_cls), # handle Ndarrays and Tensors
+ tx,
+ proxy,
+ )
- if isinstance(left, (StreamVariable, EventVariable)) or isinstance(
- right, (StreamVariable, EventVariable)
- ):
- if type(left) == type(right) and op is operator.eq:
- return ConstantVariable(op(left.value, right.value))
+ def _comparison_with_symnode(self, tx, left, right):
+ from .tensor import supported_tensor_comparison_op_values
- if isinstance(right, ConstantVariable) or isinstance(
- left, ConstantVariable
- ):
- return ConstantVariable(op(left.value, right.value))
+ op = self.fn
- if op.__name__.startswith("is_"):
- # If the two objects are of different type, we can safely return False and True for `is` and `is not`, respectively
- if type(left) is not type(right):
- return ConstantVariable.create(op.__name__ != "is_")
+ if op not in supported_tensor_comparison_op_values:
+ unimplemented(f"{op.__name__}({left}, {right})")
- if isinstance(left, BuiltinVariable) and isinstance(right, BuiltinVariable):
- return ConstantVariable.create(op(left.fn, right.fn))
-
- _unimplemented()
+ proxy = tx.output.create_proxy(
+ "call_function", op, (left.as_proxy(), right.as_proxy()), {}
+ )
+ return SymNodeVariable.create(
+ tx,
+ proxy,
+ sym_num=None,
+ )
def call_and_(self, tx, a, b):
# Rely on constant_handler
@@ -1711,14 +1796,8 @@
return None
- call_eq = _comparison
- call_gt = _comparison
- call_lt = _comparison
- call_ge = _comparison
- call_le = _comparison
- call_ne = _comparison
- call_is_ = _comparison
- call_is_not = _comparison
+ def call_contains(self, tx, a: VariableTracker, b: VariableTracker):
+ return a.call_method(tx, "__contains__", [b], {})
call_all = _polyfill_call_impl("all")
call_any = _polyfill_call_impl("any")