SetVariable in dynamo (#103205)
Set initial
Fixes https://github.com/pytorch/pytorch/issues/94738
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103205
Approved by: https://github.com/jansel
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index e787d93..107f105 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -6072,6 +6072,151 @@
self.assertIsNotNone(r.model)
self.assertIsNotNone(r.failed_source_expr)
+ def test_simple_set_usage(self):
+ def foo(x, y):
+ setty = {x, y}
+ return setty.pop() * setty.pop()
+
+ counter = CompileCounter()
+ foo = torch._dynamo.optimize(counter, nopython=True)(foo)
+ x = torch.randn(10, 10)
+ y = torch.randn(10, 10)
+ foo(x, y)
+ self.assertEqual(counter.frame_count, 1)
+
+ def test_add_to_set(self):
+ def foo(x, y):
+ setty = set()
+ setty.add(x[0])
+ setty.add(x[1])
+ setty.add(x[2])
+ setty.add(y)
+ return y * len(setty)
+
+ x = torch.randn(10, 10)
+ y = torch.randn(2, 2)
+ eager_result = foo([x, x, x, x, y], y)
+
+ counter = CompileCounter()
+ foo = torch._dynamo.optimize(counter, nopython=True)(foo)
+ result = foo([x, x, x, x, y], y)
+ self.assertEqual(counter.frame_count, 1)
+ self.assertEqual(result, eager_result)
+
+ def test_iter_set(self):
+ def foo(x, y):
+ setty = set()
+ for t in x:
+ setty.add(t)
+ return y * len(setty)
+
+ x = torch.randn(10, 10)
+ y = torch.randn(2, 2)
+ eager_result = foo([x, x, x, x, y], y)
+
+ counter = CompileCounter()
+ foo = torch._dynamo.optimize(counter, nopython=True)(foo)
+ result = foo([x, x, x, x, y], y)
+ self.assertEqual(counter.frame_count, 1)
+ self.assertEqual(result, eager_result)
+
+ def test_input_set_graph_break(self):
+ def foo(x):
+ return x.pop() * x.pop()
+
+ x = torch.randn(10, 10)
+ y = torch.randn(10, 10)
+
+ counter = CompileCounter()
+
+ inp = {x, x, x, x, y, y}
+ foo = torch._dynamo.optimize(counter, nopython=True)(foo)
+
+ # There's a lot of stuff about sets that cannot work without a good deal of exertion on our part.
+ # Specifically, getting a set as input won't ever work with how GetItemSource works (Can't arbitrary access set contents)
+ # and so the guard story for the objects passed into input just isn't there atm.
+ with self.assertRaisesRegex(
+ torch._dynamo.exc.Unsupported,
+ "^call_method UserDefinedObjectVariable\\(set\\).*",
+ ):
+ foo(inp)
+
+ foo = torch._dynamo.optimize(counter, nopython=False)(foo)
+ foo(inp)
+ self.assertEqual(counter.frame_count, 1)
+
+ def test_reconstruct_set_across_graph_break(self):
+ def foo(x, y):
+ setty = set()
+ for t in x:
+ setty.add(t)
+ print("Break!")
+ return y * len(setty)
+
+ x = torch.randn(10, 10)
+ y = torch.randn(2, 2)
+
+ counter = CompileCounter()
+ foo = torch._dynamo.optimize(counter)(foo)
+ result = foo([x, x, x, x, y], y)
+
+ def test_set_aliasing_recompiles(self):
+ g1 = torch.randn(10)
+ g2 = torch.randn(10)
+ g3 = torch.randn(10)
+ g4 = torch.randn(10)
+
+ def foo(a, b, c):
+ myset = {g1, a, b, c}
+ return a + len(myset)
+
+ counter = CompileCounter()
+ foo = torch._dynamo.optimize(counter)(foo)
+ # first call with no aliasing
+ foo(g2, g3, g4)
+ self.assertEqual(counter.frame_count, 1)
+
+ # no aliasing again
+ foo(g3, g2, g4)
+ # assert no recompile
+ self.assertEqual(counter.frame_count, 1)
+
+ # aliasing changes, we should recompile
+ foo(g2, g2, g2)
+ self.assertEqual(counter.frame_count, 2)
+
+ # same aliasing, different tensor
+ foo(g3, g3, g3)
+ self.assertEqual(counter.frame_count, 2)
+
+ # aliasing between global and arg, should recompile again
+ foo(g1, g1, g1)
+ self.assertEqual(counter.frame_count, 3)
+
+ # Reset
+ torch._dynamo.reset()
+
+ # aliasing between global and arg, first call
+ foo(g1, g1, g1)
+ self.assertEqual(counter.frame_count, 4)
+
+ # same aliasing, different tensor, all local, recompile
+ foo(g3, g3, g3)
+ self.assertEqual(counter.frame_count, 5)
+
+ # aliasing same tensor, we shouldn't recompile
+ foo(g2, g2, g2)
+ self.assertEqual(counter.frame_count, 5)
+
+ # No aliasing
+ foo(g2, g3, g4)
+ self.assertEqual(counter.frame_count, 6)
+
+ # No aliasing again
+ foo(g3, g2, g4)
+ # assert no recompile
+ self.assertEqual(counter.frame_count, 6)
+
class TestTracer(JitTestCase):
def test_jit_save(self):
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index ab586e0..7599140 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -3,6 +3,7 @@
import collections
import dataclasses
import enum
+import functools
import importlib
import itertools
import logging
@@ -18,7 +19,11 @@
import torch
import torch.utils._device
-from torch._dynamo.source import TensorProperty, TensorPropertySource
+from torch._dynamo.source import (
+ is_from_local_source,
+ TensorProperty,
+ TensorPropertySource,
+)
from torch._guards import (
DuplicateInputs,
@@ -1159,3 +1164,29 @@
if x not in seen:
yield x
seen.add(x)
+
+
+def make_dupe_guard(obj_source, dupe_source):
+ # Note - we may end up in a situation where we invoke something like
+ # def fn(x, y)
+ # with fn(x, x)
+ # Prior to the addition of tracking to all relevant objects, we would handle this just fine by
+ # eagerly re-entering VB and rewrapping inputs, correctly creating graphargs and placeholders. However,
+ # with tracking on inputs, duplicate inputs or aliased relationships may end up getting erased here -
+ # In the the fn(x, x) example call above look like a graph with a single input.
+ # In order to ensure that we do not reuse fn(x, x) for fn(x, y), we create a duplicate input guard.
+
+ # Note - we may not have a source, that is fine, it just means we had an object that is safe to have
+ # leave unsourced - like a local list created and discharged entirely within a local scope.
+ if dupe_source and dupe_source != obj_source:
+ ser_source_is_local = is_from_local_source(dupe_source)
+ source_is_local = is_from_local_source(obj_source)
+ # Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently
+ # reconcile guards builder scopes in compile_check_fn. This technically means we miss a guard here,
+ # so maybe we should do this refactor before we land this...
+ # TODO(voz): Combine local and global guard builders.
+ if ser_source_is_local == source_is_local:
+ # Note - this is a little agressive - these being duplicate input does not always matter.
+ # However, this should always be a sound guard to add here.
+ return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source)
+ return None
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index cbbd9cb..75a4496 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -84,6 +84,7 @@
BaseListVariable,
ListIteratorVariable,
ListVariable,
+ SetVariable,
SliceVariable,
TupleVariable,
)
@@ -1277,6 +1278,12 @@
options = VariableTracker.propagate(items)
self.push(ListVariable(items, mutable_local=MutableLocal(), **options))
+ def BUILD_SET(self, inst):
+ items = self.popn(inst.argval)
+ options = VariableTracker.propagate(items)
+ new_set = SetVariable(self, items, mutable_local=MutableLocal(), **options)
+ self.push(new_set)
+
def BUILD_LIST_UNPACK(self, inst, cls=ListVariable):
seqs = self.popn(inst.argval)
options = VariableTracker.propagate(seqs)
@@ -1361,6 +1368,14 @@
),
)
+ def SET_ADD(self, inst):
+ v = self.pop()
+ assert inst.argval > 0
+ obj = self.stack[-inst.arg]
+ assert isinstance(obj, SetVariable)
+ assert obj.mutable_local
+ return obj.call_method(self, "add", [v], {})
+
def LIST_APPEND(self, inst):
v = self.pop()
assert inst.argval > 0
diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py
index 4f5ff13..26306f2 100644
--- a/torch/_dynamo/variables/__init__.py
+++ b/torch/_dynamo/variables/__init__.py
@@ -22,6 +22,7 @@
ListVariable,
NamedTupleVariable,
RangeVariable,
+ SetVariable,
SliceVariable,
TupleVariable,
)
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 979788d..3a934ba 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -27,7 +27,7 @@
from .. import config, mutation_guard, replay_record, skipfiles
from ..allowed_functions import is_allowed, is_builtin_callable, is_numpy
from ..exc import unimplemented
-from ..guards import GuardBuilder
+from ..guards import GuardBuilder, make_dupe_guard
from ..side_effects import SideEffects
from ..source import (
AttrSource,
@@ -35,7 +35,6 @@
GetItemSource,
GlobalWeakRefSource,
is_constant_source,
- is_from_local_source,
LocalSource,
RandomValueSource,
Source,
@@ -82,6 +81,7 @@
ListVariable,
NamedTupleVariable,
RangeVariable,
+ SetVariable,
SizeVariable,
SliceVariable,
TupleIteratorVariable,
@@ -195,7 +195,7 @@
def __call__(self, value):
if value in self.tx.output.side_effects:
side_effect_result = self.tx.output.side_effects[value]
- dup_guard = self._make_dupe_guard(side_effect_result)
+ dup_guard = make_dupe_guard(self.source, side_effect_result.source)
if dup_guard:
side_effect_result = side_effect_result.add_guards(
self.make_guards(dup_guard)
@@ -208,34 +208,6 @@
)
return vt
- def _make_dupe_guard(self, deduped_object):
- # Note - we may end up in a situation where we invoke something like
- # def fn(x, y)
- # with fn(x, x)
- # Prior to the addition of tracking to all relevant objects, we would handle this just fine by
- # eagerly re-entering VB and rewrapping inputs, correctly creating graphargs and placeholders. However,
- # with tracking on inputs, duplicate inputs or aliased relationships may end up getting erased here -
- # In the the fn(x, x) example call above look like a graph with a single input.
- # In order to ensure that we do not reuse fn(x, x) for fn(x, y), we create a duplicate input guard.
-
- # Note - we may not have a source, that is fine, it just means we had an object that is safe to have
- # leave unsourced - like a local list created and discharged entirely within a local scope.
- if deduped_object.source and deduped_object.source != self.source:
- ser_source_is_local = is_from_local_source(deduped_object.source)
- source_is_local = is_from_local_source(self.source)
- # Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently
- # reconcile guards builder scopes in compile_check_fn. This technically means we miss a guard here,
- # so maybe we should do this refactor before we land this...
- # TODO(voz): Combine local and global guard builders.
- if ser_source_is_local == source_is_local:
- # Note - this is a little agressive - these being duplicate input does not always matter.
- # However, this should always be a sound guard to add here.
- dup_guard = functools.partial(
- GuardBuilder.DUPLICATE_INPUT, source_b=deduped_object.source
- )
- return dup_guard
- return None
-
def _can_lift_attrs_to_inputs(self, vt):
if type(vt) in [
TensorVariable,
@@ -263,6 +235,7 @@
def list_type(value):
if is_namedtuple(value):
return functools.partial(NamedTupleVariable, tuple_cls=type(value))
+ # TODO(voz): Why do we have both this and `BaseListVariable`'s `cls_for`?
return {
tuple: TupleVariable,
list: ListVariable,
@@ -386,6 +359,7 @@
return self.wrap_tensor(value)
elif is_namedtuple(value):
return self.wrap_listlike(value)
+
elif istype(
value, (dict, collections.defaultdict, collections.OrderedDict)
) and all(
@@ -701,7 +675,9 @@
).add_guards(guards)
for i, item in enumerate(value)
]
- result = self.list_type(value)(output, guards=guards)
+ result = self.list_type(value)(
+ output, mutable_local=MutableLocal(), guards=guards
+ )
if istype(value, list):
return self.tx.output.side_effects.track_list(self.source, value, result)
return result
@@ -899,6 +875,21 @@
# then the relevant SubgraphTracer will lift it to being an input of
# the subgraph.
# See NOTE [HigherOrderOperator tracing design] for more details.
+
+ if not self.tx.output.export:
+ # Export has (supposedly) valid cases for fake tensors as inputs here.
+ # I am not convinced, atm, but out of scope for what this assert was added for (protecting value checks
+ # in real_value_tensor_positive_aliases in the common case)
+ assert not isinstance(value, torch._subclasses.fake_tensor.FakeTensor)
+
+ if value in self.tx.output.real_value_tensor_positive_aliases:
+ stored_value = self.tx.output.real_value_tensor_positive_aliases[value]
+ # TODO(voz): Decently common pattern, refactor at some point.
+ dup_guard = self._make_dupe_guard(stored_value)
+ if dup_guard:
+ stored_value = stored_value.add_guards(self.make_guards(dup_guard))
+ return stored_value
+
tensor_proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value)
)
@@ -1258,7 +1249,7 @@
):
sizes = [ConstantVariable(x) for x in example_value]
return SizeVariable(sizes, **options)
- elif isinstance(example_value, (tuple, list)):
+ elif isinstance(example_value, (tuple, list, set)):
proxy.node.meta["example_value"] = example_value
unpacked = []
for i, val in enumerate(example_value):
@@ -1287,6 +1278,8 @@
return TupleVariable(unpacked, **options)
elif istype(example_value, (list, immutable_list)):
return ListVariable(unpacked, mutable_local=MutableLocal(), **options)
+ elif istype(example_value, set):
+ return SetVariable(tx, unpacked, mutable_local=MutableLocal(), **options)
else:
assert example_value.__class__.__module__ == "torch.return_types" or hasattr(
example_value, "_fields"
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 3975f7a..1d3ca4f 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -40,6 +40,7 @@
BaseListVariable,
ListIteratorVariable,
ListVariable,
+ SetVariable,
SizeVariable,
TupleIteratorVariable,
TupleVariable,
@@ -71,7 +72,6 @@
pow,
repr,
round,
- set,
str,
str.format,
sum,
@@ -762,10 +762,17 @@
return self._dyn_proxy(tx, *args, **kwargs)
cls = variables.BaseListVariable.cls_for(self.fn)
if obj is None:
- return cls(
- [],
- mutable_local=MutableLocal(),
- )
+ if cls is SetVariable:
+ return cls(
+ tx,
+ [],
+ mutable_local=MutableLocal(),
+ )
+ else:
+ return cls(
+ [],
+ mutable_local=MutableLocal(),
+ )
elif obj.has_unpack_var_sequence(tx):
guards = set()
if obj.source and not is_constant_source(obj.source):
@@ -773,6 +780,14 @@
guards.add(obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN))
else:
guards.add(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
+ if cls is SetVariable:
+ return cls(
+ tx,
+ list(obj.unpack_var_sequence(tx)),
+ mutable_local=MutableLocal(),
+ guards=guards,
+ ).add_options(self, obj)
+
return cls(
list(obj.unpack_var_sequence(tx)),
mutable_local=MutableLocal(),
@@ -782,6 +797,7 @@
call_iter = _call_iter_tuple_list
call_tuple = _call_iter_tuple_list
call_list = _call_iter_tuple_list
+ call_set = _call_iter_tuple_list
@staticmethod
def is_supported_call_dict_arg(tx, arg):
@@ -1283,6 +1299,11 @@
_unimplemented()
return BaseListVariable.list_compare(tx, op, left, right)
+ if isinstance(left, SetVariable):
+ if not type(left) == type(right): # Mismatch in BaseListVariable subclasses
+ _unimplemented()
+ return ConstantVariable(op(left._underlying_items, right._underlying_items))
+
if isinstance(left, TensorVariable):
from .builder import wrap_fx_proxy
diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py
index 7b167f2..2007c70 100644
--- a/torch/_dynamo/variables/lists.py
+++ b/torch/_dynamo/variables/lists.py
@@ -1,4 +1,5 @@
import collections
+import dataclasses
import functools
import inspect
import operator
@@ -11,6 +12,7 @@
from .. import variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import unimplemented
+from ..guards import make_dupe_guard
from ..source import GetItemSource
from ..utils import check_constant_args, guard_if_dyn, namedtuple_fields
from .base import MutableLocal, VariableTracker
@@ -27,6 +29,7 @@
slice: SliceVariable,
torch.Size: SizeVariable,
tuple: TupleVariable,
+ set: SetVariable,
}[obj]
def __init__(
@@ -723,3 +726,150 @@
pytree._tuple_to_str,
pytree._maybe_str_to_tuple,
)
+
+
+class SetVariable(VariableTracker):
+ @dataclasses.dataclass
+ class SetElement:
+ vt: VariableTracker
+ underlying_value: Any
+
+ def __hash__(self) -> int:
+ return hash(self.underlying_value)
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, SetVariable.SetElement):
+ return False
+ if isinstance(self.vt, variables.TensorVariable):
+ return self.underlying_value is other.underlying_value
+ else:
+ return self.underlying_value == other.underlying_value
+
+ def __init__(
+ self,
+ tx,
+ items: List[VariableTracker],
+ recursively_contains=None,
+ regen_guards=True,
+ **kwargs,
+ ):
+ super().__init__(recursively_contains=recursively_contains, **kwargs)
+ # Note - Set is still backed by a list, because we want set behavior over the contents,
+ assert isinstance(items, list)
+ assert all(isinstance(x, VariableTracker) for x in items)
+
+ self.items = []
+ self._add(tx, items)
+
+ # Sometimes, we know that we have passed in the guards from the items in the set
+ if regen_guards:
+ self.guards.update(VariableTracker.propagate(items)["guards"])
+
+ # Really annoying to store this here - but required because of how
+ # VariableTracker's clone works w/r/t attr setting from dict
+ self.tx = tx
+
+ def as_proxy(self):
+ return [x.as_proxy() for x in self.items]
+
+ def python_type(self):
+ return set
+
+ def reconstruct(self, codegen):
+ codegen.load_import_from("builtins", "set")
+ codegen.foreach(self.items)
+ return [
+ create_instruction("BUILD_SET", arg=len(self.items))
+ ] + create_call_function(1, True)
+
+ # Note - this is only used for producing a set
+ def _as_set_element(self, tx, vt):
+ from .base import VariableTracker
+ from .tensor import TensorVariable
+
+ assert isinstance(vt, VariableTracker)
+
+ if isinstance(vt, TensorVariable):
+ tensor_node = vt.as_proxy().node
+ return SetVariable.SetElement(vt, tensor_node)
+ if isinstance(vt, ConstantVariable):
+ return SetVariable.SetElement(vt, vt.value)
+
+ unimplemented(f"Sets with {type(vt)} NYI")
+
+ @property
+ def _underlying_items(self):
+ underlying_items = set()
+ for current_item in self.items:
+ assert (
+ current_item not in underlying_items
+ ), "Items modeling set invariant violated"
+ underlying_items.add(self._as_set_element(self.tx, current_item))
+ return underlying_items
+
+ def _add(self, tx, item):
+ underlying_items = self._underlying_items
+
+ if isinstance(item, (list, set)):
+ items_to_add = item
+ else:
+ items_to_add = [item]
+
+ for item_to_add in items_to_add:
+ set_element = self._as_set_element(tx, item_to_add)
+ if set_element not in underlying_items:
+ underlying_items.add(set_element)
+ self.items.append(set_element.vt)
+ else:
+ for e in underlying_items:
+ if hash(set_element) == hash(e):
+ alias_guard = make_dupe_guard(
+ e.vt.source, set_element.vt.source
+ )
+ if alias_guard:
+ e.vt = e.vt.add_guards(
+ {e.vt.source.make_guard(alias_guard)}
+ )
+
+ return self.items
+
+ def call_method(
+ self,
+ tx,
+ name,
+ args: List[VariableTracker],
+ kwargs: Dict[str, VariableTracker],
+ ) -> "VariableTracker":
+ options = VariableTracker.propagate(self, args, kwargs.values())
+ # Somewhat duplicative of CommonListMethodsVariable - but better than to violate substitution
+ # principles and end up with things like direct item access attempts on a set, or
+ # getitem sources.
+ if name == "add" and args and self.mutable_local:
+ assert not kwargs
+ item = args[0]
+ result = SetVariable(
+ tx,
+ self._add(tx, item),
+ mutable_local=self.mutable_local,
+ regen_guards=False,
+ **options,
+ )
+ tx.replace_all(self, result)
+ return ConstantVariable(None)
+ elif name == "pop" and self.mutable_local:
+ assert not kwargs
+ assert not args
+ items = list(self.items)
+ result = items.pop()
+ tx.replace_all(
+ self,
+ SetVariable(tx, items, regen_guards=False, **options),
+ )
+ return result
+ elif name == "__len__":
+ return ConstantVariable(len(self.items)).add_options(options)
+ else:
+ return super().call_method(tx, name, args, kwargs)
+
+ def getitem_const(self, arg: VariableTracker):
+ raise RuntimeError("Illegal to getitem on a set")