SetVariable in dynamo (#103205)

Adds an initial SetVariable to model Python sets in dynamo. This covers the BUILD_SET and SET_ADD bytecodes, the set() builtin, the add/pop/__len__ methods and set comparisons, and reconstruction of sets across graph breaks. Aliasing between set members (and between members and globals) is guarded with DUPLICATE_INPUT so that changes in aliasing trigger recompilation. Sets passed directly as graph inputs still cause a graph break.
Fixes https://github.com/pytorch/pytorch/issues/94738
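
For reference, a minimal sketch of the kind of code this enables, adapted from the new test_simple_set_usage test (CompileCounter is the test backend from torch._dynamo.testing):

```python
import torch
import torch._dynamo
from torch._dynamo.testing import CompileCounter

def foo(x, y):
    # The set literal is modeled by SetVariable (BUILD_SET), and pop() is supported on it.
    setty = {x, y}
    return setty.pop() * setty.pop()

counter = CompileCounter()
opt_foo = torch._dynamo.optimize(counter, nopython=True)(foo)
opt_foo(torch.randn(10, 10), torch.randn(10, 10))
assert counter.frame_count == 1  # the whole function compiles as a single frame
```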

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103205
Approved by: https://github.com/jansel
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index e787d93..107f105 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -6072,6 +6072,151 @@
         self.assertIsNotNone(r.model)
         self.assertIsNotNone(r.failed_source_expr)
 
+    def test_simple_set_usage(self):
+        def foo(x, y):
+            setty = {x, y}
+            return setty.pop() * setty.pop()
+
+        counter = CompileCounter()
+        foo = torch._dynamo.optimize(counter, nopython=True)(foo)
+        x = torch.randn(10, 10)
+        y = torch.randn(10, 10)
+        foo(x, y)
+        self.assertEqual(counter.frame_count, 1)
+
+    def test_add_to_set(self):
+        def foo(x, y):
+            setty = set()
+            setty.add(x[0])
+            setty.add(x[1])
+            setty.add(x[2])
+            setty.add(y)
+            return y * len(setty)
+
+        x = torch.randn(10, 10)
+        y = torch.randn(2, 2)
+        eager_result = foo([x, x, x, x, y], y)
+
+        counter = CompileCounter()
+        foo = torch._dynamo.optimize(counter, nopython=True)(foo)
+        result = foo([x, x, x, x, y], y)
+        self.assertEqual(counter.frame_count, 1)
+        self.assertEqual(result, eager_result)
+
+    def test_iter_set(self):
+        def foo(x, y):
+            setty = set()
+            for t in x:
+                setty.add(t)
+            return y * len(setty)
+
+        x = torch.randn(10, 10)
+        y = torch.randn(2, 2)
+        eager_result = foo([x, x, x, x, y], y)
+
+        counter = CompileCounter()
+        foo = torch._dynamo.optimize(counter, nopython=True)(foo)
+        result = foo([x, x, x, x, y], y)
+        self.assertEqual(counter.frame_count, 1)
+        self.assertEqual(result, eager_result)
+
+    def test_input_set_graph_break(self):
+        def foo(x):
+            return x.pop() * x.pop()
+
+        x = torch.randn(10, 10)
+        y = torch.randn(10, 10)
+
+        counter = CompileCounter()
+
+        inp = {x, x, x, x, y, y}
+        foo = torch._dynamo.optimize(counter, nopython=True)(foo)
+
+        # There is a lot about sets that cannot work without a good deal of extra effort on our part.
+        # Specifically, taking a set as input won't ever work with how GetItemSource works (we can't
+        # arbitrarily access set contents), so the guard story for objects passed in as input just isn't there atm.
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.Unsupported,
+            "^call_method UserDefinedObjectVariable\\(set\\).*",
+        ):
+            foo(inp)
+
+        foo = torch._dynamo.optimize(counter, nopython=False)(foo)
+        foo(inp)
+        self.assertEqual(counter.frame_count, 1)
+
+    def test_reconstruct_set_across_graph_break(self):
+        def foo(x, y):
+            setty = set()
+            for t in x:
+                setty.add(t)
+            print("Break!")
+            return y * len(setty)
+
+        x = torch.randn(10, 10)
+        y = torch.randn(2, 2)
+
+        counter = CompileCounter()
+        foo = torch._dynamo.optimize(counter)(foo)
+        result = foo([x, x, x, x, y], y)
+
+    def test_set_aliasing_recompiles(self):
+        g1 = torch.randn(10)
+        g2 = torch.randn(10)
+        g3 = torch.randn(10)
+        g4 = torch.randn(10)
+
+        def foo(a, b, c):
+            myset = {g1, a, b, c}
+            return a + len(myset)
+
+        counter = CompileCounter()
+        foo = torch._dynamo.optimize(counter)(foo)
+        # first call with no aliasing
+        foo(g2, g3, g4)
+        self.assertEqual(counter.frame_count, 1)
+
+        # no aliasing again
+        foo(g3, g2, g4)
+        # assert no recompile
+        self.assertEqual(counter.frame_count, 1)
+
+        # aliasing changes, we should recompile
+        foo(g2, g2, g2)
+        self.assertEqual(counter.frame_count, 2)
+
+        # same aliasing, different tensor
+        foo(g3, g3, g3)
+        self.assertEqual(counter.frame_count, 2)
+
+        # aliasing between global and arg, should recompile again
+        foo(g1, g1, g1)
+        self.assertEqual(counter.frame_count, 3)
+
+        # Reset
+        torch._dynamo.reset()
+
+        # aliasing between global and arg, first call
+        foo(g1, g1, g1)
+        self.assertEqual(counter.frame_count, 4)
+
+        # same aliasing, different tensor, all local, recompile
+        foo(g3, g3, g3)
+        self.assertEqual(counter.frame_count, 5)
+
+        # same aliasing pattern, different tensor, we shouldn't recompile
+        foo(g2, g2, g2)
+        self.assertEqual(counter.frame_count, 5)
+
+        # No aliasing
+        foo(g2, g3, g4)
+        self.assertEqual(counter.frame_count, 6)
+
+        # No aliasing again
+        foo(g3, g2, g4)
+        # assert no recompile
+        self.assertEqual(counter.frame_count, 6)
+
 
 class TestTracer(JitTestCase):
     def test_jit_save(self):
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index ab586e0..7599140 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -3,6 +3,7 @@
 import collections
 import dataclasses
 import enum
+import functools
 import importlib
 import itertools
 import logging
@@ -18,7 +19,11 @@
 
 import torch
 import torch.utils._device
-from torch._dynamo.source import TensorProperty, TensorPropertySource
+from torch._dynamo.source import (
+    is_from_local_source,
+    TensorProperty,
+    TensorPropertySource,
+)
 
 from torch._guards import (
     DuplicateInputs,
@@ -1159,3 +1164,29 @@
         if x not in seen:
             yield x
             seen.add(x)
+
+
+def make_dupe_guard(obj_source, dupe_source):
+    # Note - we may end up in a situation where we invoke something like
+    # def fn(x, y)
+    # with fn(x, x)
+    # Prior to the addition of tracking to all relevant objects, we would handle this just fine by
+    # eagerly re-entering VB and rewrapping inputs, correctly creating graphargs and placeholders. However,
+    # with tracking on inputs, duplicate inputs or aliased relationships may end up getting erased here -
+    # the fn(x, x) example call above would end up looking like a graph with a single input.
+    # In order to ensure that we do not reuse fn(x, x) for fn(x, y), we create a duplicate input guard.
+
+    # Note - we may not have a source; that is fine, it just means we had an object that is safe to
+    # leave unsourced - like a local list created and discharged entirely within a local scope.
+    if dupe_source and dupe_source != obj_source:
+        ser_source_is_local = is_from_local_source(dupe_source)
+        source_is_local = is_from_local_source(obj_source)
+        # Note - both sources must be local, or both must be global, or we will run afoul of a lack of merging
+        # in how we currently reconcile guard builder scopes in compile_check_fn. This technically means we miss
+        # a guard here, so maybe we should do this refactor before landing this...
+        # TODO(voz): Combine local and global guard builders.
+        if ser_source_is_local == source_is_local:
+            # Note - this is a little aggressive - these being duplicate inputs does not always matter.
+            # However, this should always be a sound guard to add here.
+            return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source)
+    return None
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index cbbd9cb..75a4496 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -84,6 +84,7 @@
     BaseListVariable,
     ListIteratorVariable,
     ListVariable,
+    SetVariable,
     SliceVariable,
     TupleVariable,
 )
@@ -1277,6 +1278,12 @@
         options = VariableTracker.propagate(items)
         self.push(ListVariable(items, mutable_local=MutableLocal(), **options))
 
+    def BUILD_SET(self, inst):
+        items = self.popn(inst.argval)
+        options = VariableTracker.propagate(items)
+        new_set = SetVariable(self, items, mutable_local=MutableLocal(), **options)
+        self.push(new_set)
+
     def BUILD_LIST_UNPACK(self, inst, cls=ListVariable):
         seqs = self.popn(inst.argval)
         options = VariableTracker.propagate(seqs)
@@ -1361,6 +1368,14 @@
             ),
         )
 
+    def SET_ADD(self, inst):
+        v = self.pop()
+        assert inst.argval > 0
+        obj = self.stack[-inst.arg]
+        assert isinstance(obj, SetVariable)
+        assert obj.mutable_local
+        return obj.call_method(self, "add", [v], {})
+
     def LIST_APPEND(self, inst):
         v = self.pop()
         assert inst.argval > 0
diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py
index 4f5ff13..26306f2 100644
--- a/torch/_dynamo/variables/__init__.py
+++ b/torch/_dynamo/variables/__init__.py
@@ -22,6 +22,7 @@
     ListVariable,
     NamedTupleVariable,
     RangeVariable,
+    SetVariable,
     SliceVariable,
     TupleVariable,
 )
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 979788d..3a934ba 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -27,7 +27,7 @@
 from .. import config, mutation_guard, replay_record, skipfiles
 from ..allowed_functions import is_allowed, is_builtin_callable, is_numpy
 from ..exc import unimplemented
-from ..guards import GuardBuilder
+from ..guards import GuardBuilder, make_dupe_guard
 from ..side_effects import SideEffects
 from ..source import (
     AttrSource,
@@ -35,7 +35,6 @@
     GetItemSource,
     GlobalWeakRefSource,
     is_constant_source,
-    is_from_local_source,
     LocalSource,
     RandomValueSource,
     Source,
@@ -82,6 +81,7 @@
     ListVariable,
     NamedTupleVariable,
     RangeVariable,
+    SetVariable,
     SizeVariable,
     SliceVariable,
     TupleIteratorVariable,
@@ -195,7 +195,7 @@
     def __call__(self, value):
         if value in self.tx.output.side_effects:
             side_effect_result = self.tx.output.side_effects[value]
-            dup_guard = self._make_dupe_guard(side_effect_result)
+            dup_guard = make_dupe_guard(self.source, side_effect_result.source)
             if dup_guard:
                 side_effect_result = side_effect_result.add_guards(
                     self.make_guards(dup_guard)
@@ -208,34 +208,6 @@
             )
         return vt
 
-    def _make_dupe_guard(self, deduped_object):
-        # Note - we may end up in a situation where we invoke something like
-        # def fn(x, y)
-        # with fn(x, x)
-        # Prior to the addition of tracking to all relevant objects, we would handle this just fine by
-        # eagerly re-entering VB and rewrapping inputs, correctly creating graphargs and placeholders. However,
-        # with tracking on inputs, duplicate inputs or aliased relationships may end up getting erased here -
-        # In the the fn(x, x) example call above look like a graph with a single input.
-        # In order to ensure that we do not reuse fn(x, x) for fn(x, y), we create a duplicate input guard.
-
-        # Note - we may not have a source, that is fine, it just means we had an object that is safe to have
-        # leave unsourced - like a local list created and discharged entirely within a local scope.
-        if deduped_object.source and deduped_object.source != self.source:
-            ser_source_is_local = is_from_local_source(deduped_object.source)
-            source_is_local = is_from_local_source(self.source)
-            # Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently
-            # reconcile guards builder scopes in compile_check_fn. This technically means we miss a guard here,
-            # so maybe we should do this refactor before we land this...
-            # TODO(voz): Combine local and global guard builders.
-            if ser_source_is_local == source_is_local:
-                # Note - this is a little agressive - these being duplicate input does not always matter.
-                # However, this should always be a sound guard to add here.
-                dup_guard = functools.partial(
-                    GuardBuilder.DUPLICATE_INPUT, source_b=deduped_object.source
-                )
-                return dup_guard
-        return None
-
     def _can_lift_attrs_to_inputs(self, vt):
         if type(vt) in [
             TensorVariable,
@@ -263,6 +235,7 @@
     def list_type(value):
         if is_namedtuple(value):
             return functools.partial(NamedTupleVariable, tuple_cls=type(value))
+        # TODO(voz): Why do we have both this and `BaseListVariable`'s `cls_for`?
         return {
             tuple: TupleVariable,
             list: ListVariable,
@@ -386,6 +359,7 @@
             return self.wrap_tensor(value)
         elif is_namedtuple(value):
             return self.wrap_listlike(value)
+
         elif istype(
             value, (dict, collections.defaultdict, collections.OrderedDict)
         ) and all(
@@ -701,7 +675,9 @@
             ).add_guards(guards)
             for i, item in enumerate(value)
         ]
-        result = self.list_type(value)(output, guards=guards)
+        result = self.list_type(value)(
+            output, mutable_local=MutableLocal(), guards=guards
+        )
         if istype(value, list):
             return self.tx.output.side_effects.track_list(self.source, value, result)
         return result
@@ -899,6 +875,21 @@
         # then the relevant SubgraphTracer will lift it to being an input of
         # the subgraph.
         # See NOTE [HigherOrderOperator tracing design] for more details.
+
+        if not self.tx.output.export:
+            # Export has (supposedly) valid cases for fake tensors as inputs here.
+            # I am not convinced atm, but that is out of scope for what this assert was added for (protecting
+            # value checks in real_value_tensor_positive_aliases in the common case).
+            assert not isinstance(value, torch._subclasses.fake_tensor.FakeTensor)
+
+        if value in self.tx.output.real_value_tensor_positive_aliases:
+            stored_value = self.tx.output.real_value_tensor_positive_aliases[value]
+            # TODO(voz): Decently common pattern, refactor at some point.
+            dup_guard = make_dupe_guard(self.source, stored_value.source)
+            if dup_guard:
+                stored_value = stored_value.add_guards(self.make_guards(dup_guard))
+            return stored_value
+
         tensor_proxy = self.tx.output.root_tracer.create_graph_input(
             re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value)
         )
@@ -1258,7 +1249,7 @@
     ):
         sizes = [ConstantVariable(x) for x in example_value]
         return SizeVariable(sizes, **options)
-    elif isinstance(example_value, (tuple, list)):
+    elif isinstance(example_value, (tuple, list, set)):
         proxy.node.meta["example_value"] = example_value
         unpacked = []
         for i, val in enumerate(example_value):
@@ -1287,6 +1278,8 @@
             return TupleVariable(unpacked, **options)
         elif istype(example_value, (list, immutable_list)):
             return ListVariable(unpacked, mutable_local=MutableLocal(), **options)
+        elif istype(example_value, set):
+            return SetVariable(tx, unpacked, mutable_local=MutableLocal(), **options)
         else:
             assert example_value.__class__.__module__ == "torch.return_types" or hasattr(
                 example_value, "_fields"
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 3975f7a..1d3ca4f 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -40,6 +40,7 @@
     BaseListVariable,
     ListIteratorVariable,
     ListVariable,
+    SetVariable,
     SizeVariable,
     TupleIteratorVariable,
     TupleVariable,
@@ -71,7 +72,6 @@
             pow,
             repr,
             round,
-            set,
             str,
             str.format,
             sum,
@@ -762,10 +762,17 @@
             return self._dyn_proxy(tx, *args, **kwargs)
         cls = variables.BaseListVariable.cls_for(self.fn)
         if obj is None:
-            return cls(
-                [],
-                mutable_local=MutableLocal(),
-            )
+            if cls is SetVariable:
+                return cls(
+                    tx,
+                    [],
+                    mutable_local=MutableLocal(),
+                )
+            else:
+                return cls(
+                    [],
+                    mutable_local=MutableLocal(),
+                )
         elif obj.has_unpack_var_sequence(tx):
             guards = set()
             if obj.source and not is_constant_source(obj.source):
@@ -773,6 +780,14 @@
                     guards.add(obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN))
                 else:
                     guards.add(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
+            if cls is SetVariable:
+                return cls(
+                    tx,
+                    list(obj.unpack_var_sequence(tx)),
+                    mutable_local=MutableLocal(),
+                    guards=guards,
+                ).add_options(self, obj)
+
             return cls(
                 list(obj.unpack_var_sequence(tx)),
                 mutable_local=MutableLocal(),
@@ -782,6 +797,7 @@
     call_iter = _call_iter_tuple_list
     call_tuple = _call_iter_tuple_list
     call_list = _call_iter_tuple_list
+    call_set = _call_iter_tuple_list
 
     @staticmethod
     def is_supported_call_dict_arg(tx, arg):
@@ -1283,6 +1299,11 @@
                 _unimplemented()
             return BaseListVariable.list_compare(tx, op, left, right)
 
+        if isinstance(left, SetVariable):
+            if type(left) != type(right):  # Mismatched VariableTracker types
+                _unimplemented()
+            return ConstantVariable(op(left._underlying_items, right._underlying_items))
+
         if isinstance(left, TensorVariable):
             from .builder import wrap_fx_proxy
 
diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py
index 7b167f2..2007c70 100644
--- a/torch/_dynamo/variables/lists.py
+++ b/torch/_dynamo/variables/lists.py
@@ -1,4 +1,5 @@
 import collections
+import dataclasses
 import functools
 import inspect
 import operator
@@ -11,6 +12,7 @@
 from .. import variables
 from ..bytecode_transformation import create_call_function, create_instruction
 from ..exc import unimplemented
+from ..guards import make_dupe_guard
 from ..source import GetItemSource
 from ..utils import check_constant_args, guard_if_dyn, namedtuple_fields
 from .base import MutableLocal, VariableTracker
@@ -27,6 +29,7 @@
             slice: SliceVariable,
             torch.Size: SizeVariable,
             tuple: TupleVariable,
+            set: SetVariable,
         }[obj]
 
     def __init__(
@@ -723,3 +726,150 @@
         pytree._tuple_to_str,
         pytree._maybe_str_to_tuple,
     )
+
+
+class SetVariable(VariableTracker):
+    @dataclasses.dataclass
+    class SetElement:
+        vt: VariableTracker
+        underlying_value: Any
+
+        def __hash__(self) -> int:
+            return hash(self.underlying_value)
+
+        def __eq__(self, other: Any) -> bool:
+            if not isinstance(other, SetVariable.SetElement):
+                return False
+            if isinstance(self.vt, variables.TensorVariable):
+                return self.underlying_value is other.underlying_value
+            else:
+                return self.underlying_value == other.underlying_value
+
+    def __init__(
+        self,
+        tx,
+        items: List[VariableTracker],
+        recursively_contains=None,
+        regen_guards=True,
+        **kwargs,
+    ):
+        super().__init__(recursively_contains=recursively_contains, **kwargs)
+        # Note - the set is still backed by a list; we enforce set semantics over the contents ourselves (see _add).
+        assert isinstance(items, list)
+        assert all(isinstance(x, VariableTracker) for x in items)
+
+        self.items = []
+        self._add(tx, items)
+
+        # Sometimes the caller has already propagated the guards from the items, in which case regenerating them here is skipped.
+        if regen_guards:
+            self.guards.update(VariableTracker.propagate(items)["guards"])
+
+        # Really annoying to store this here - but required because of how
+        # VariableTracker's clone works w/r/t attr setting from dict
+        self.tx = tx
+
+    def as_proxy(self):
+        return [x.as_proxy() for x in self.items]
+
+    def python_type(self):
+        return set
+
+    def reconstruct(self, codegen):
+        codegen.load_import_from("builtins", "set")
+        codegen.foreach(self.items)
+        return [
+            create_instruction("BUILD_SET", arg=len(self.items))
+        ] + create_call_function(1, True)
+
+    # Note - this is only used to produce hashable elements for the underlying set representation
+    def _as_set_element(self, tx, vt):
+        from .base import VariableTracker
+        from .tensor import TensorVariable
+
+        assert isinstance(vt, VariableTracker)
+
+        if isinstance(vt, TensorVariable):
+            tensor_node = vt.as_proxy().node
+            return SetVariable.SetElement(vt, tensor_node)
+        if isinstance(vt, ConstantVariable):
+            return SetVariable.SetElement(vt, vt.value)
+
+        unimplemented(f"Sets with {type(vt)} NYI")
+
+    @property
+    def _underlying_items(self):
+        underlying_items = set()
+        for current_item in self.items:
+            set_element = self._as_set_element(self.tx, current_item)
+            assert (
+                set_element not in underlying_items
+            ), "Items modeling set invariant violated"
+            underlying_items.add(set_element)
+        return underlying_items
+
+    def _add(self, tx, item):
+        underlying_items = self._underlying_items
+
+        if isinstance(item, (list, set)):
+            items_to_add = item
+        else:
+            items_to_add = [item]
+
+        for item_to_add in items_to_add:
+            set_element = self._as_set_element(tx, item_to_add)
+            if set_element not in underlying_items:
+                underlying_items.add(set_element)
+                self.items.append(set_element.vt)
+            else:
+                for e in underlying_items:
+                    if hash(set_element) == hash(e):
+                        alias_guard = make_dupe_guard(
+                            e.vt.source, set_element.vt.source
+                        )
+                        if alias_guard:
+                            e.vt = e.vt.add_guards(
+                                {e.vt.source.make_guard(alias_guard)}
+                            )
+
+        return self.items
+
+    def call_method(
+        self,
+        tx,
+        name,
+        args: List[VariableTracker],
+        kwargs: Dict[str, VariableTracker],
+    ) -> "VariableTracker":
+        options = VariableTracker.propagate(self, args, kwargs.values())
+        # Somewhat duplicative of CommonListMethodsVariable - but better than violating substitution
+        # principles and end up with things like direct item access attempts on a set, or
+        # getitem sources.
+        if name == "add" and args and self.mutable_local:
+            assert not kwargs
+            item = args[0]
+            result = SetVariable(
+                tx,
+                self._add(tx, item),
+                mutable_local=self.mutable_local,
+                regen_guards=False,
+                **options,
+            )
+            tx.replace_all(self, result)
+            return ConstantVariable(None)
+        elif name == "pop" and self.mutable_local:
+            assert not kwargs
+            assert not args
+            items = list(self.items)
+            result = items.pop()
+            tx.replace_all(
+                self,
+                SetVariable(tx, items, regen_guards=False, **options),
+            )
+            return result
+        elif name == "__len__":
+            return ConstantVariable(len(self.items)).add_options(options)
+        else:
+            return super().call_method(tx, name, args, kwargs)
+
+    def getitem_const(self, arg: VariableTracker):
+        raise RuntimeError("Illegal to getitem on a set")