During export, generate Python TENSOR_MATCH guards (#94970)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94970
Approved by: https://github.com/ezyang
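
With this change, TENSOR_MATCH guards in export mode are emitted as plain Python expressions over tensor properties (type, shape, stride, dtype, device, requires_grad, ndimension), instead of going through the C++ `___check_tensors` fast path used in eager; only the dispatch key check is intentionally omitted. The snippet below is a rough, hypothetical sketch of the kind of guard source this produces; the helper name `export_tensor_guards` and the use of repr() for the right-hand sides are illustrative simplifications, not code from this PR.

```python
# Minimal sketch (assumptions: helper name and repr() formatting are ours) of the
# per-tensor guard expressions export mode now generates as Python source.
import torch

def export_tensor_guards(name: str, value: torch.Tensor, dynamic_shapes: bool = False):
    # Mirrors the `terms` list added to GuardBuilder.TENSOR_MATCH in this PR.
    terms = ["dtype", "device.type", "device.index", "requires_grad", "ndimension()"]
    code = []
    if not dynamic_shapes:
        terms.append("stride()")
        # Shape is compared as a tuple so the guard source does not mention torch.Size.
        code.append(f"{name}.shape == {tuple(value.shape)}")
    for term in terms:
        # Walk the attribute/call chain ("device.type", "ndimension()", ...) on the
        # example tensor to render the expected value into the guard expression.
        obj = value
        for part in term.split("."):
            obj = getattr(obj, part.rstrip("()"))
            if part.endswith("()"):
                obj = obj()
        code.append(f"{name}.{term} == {obj!r}")
    return code

x = torch.randn(2, 3)
print("\n".join(export_tensor_guards("x", x)))
# x.shape == (2, 3)
# x.dtype == torch.float32
# x.device.type == 'cpu'
# x.device.index == None
# x.requires_grad == False
# x.ndimension() == 2
# x.stride() == (3, 1)
```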
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 087141a..6556fdf 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -2365,7 +2365,6 @@
self.assertIs(x_ref(), None)
def test_release_module_memory(self):
-
mod = torch.nn.Linear(10, 10)
x = torch.rand([10, 10])
mod_weight_ref = weakref.ref(mod.weight)
@@ -2711,7 +2710,6 @@
self.names = []
def forward(self, idx, targets=None):
-
b, t = idx.size()
assert (
t <= self.block_size
@@ -3832,7 +3830,6 @@
self.assertTrue(same(ref, res))
def test_disable_flag(self):
-
cnt = torch._dynamo.testing.CompileCounter()
with patch.dict(os.environ, {"TORCH_COMPILE_DISABLE": "1"}):
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index 41e5b7a..aa64225 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -117,6 +117,7 @@
# tensor match guards make sure we actually have tensors)
self.shape_env_code: List[str] = []
+ # [Note - On Eager Tensor Guards]
# Most of the time, we generate Python code in a guard to directly
# check various properties. However, tensors are a bit special;
# it is too slow to check their properties one-by-one in Python.
@@ -131,7 +132,6 @@
self.tensor_check_names: List[str] = []
self.tensor_check_examples: List[torch.Tensor] = []
- self.tensor_check_ids: Dict[str, int] = {}
self.check_fn_manager: CheckFunctionManager = check_fn_manager
# Warning: use this with care! This lets you access what the current
@@ -429,23 +429,49 @@
value = self.get(guard.name)
assert isinstance(value, torch.Tensor)
tensor_name = self.arg_ref(guard)
- self.tensor_check_names.append(tensor_name)
- self.tensor_check_examples.append(value)
+ # [Note - On Export Tensor Guards]
+ #
+ # In eager mode, tensor guards are evaluated through C++ (in guards.cpp);
+ # see [Note - On Eager Tensor Guards] for more info.
+ #
+ # In export mode, we instead maintain parallel logic between C++ and Python
+ # here, with the exception of the dispatch key check - the idea being that a
+ # dispatch key is an entirely runtime notion that makes no sense to keep in an exported graph.
+ #
+ # This mental model is sufficient for now but, to paraphrase @ezyang, not entirely true.
+ # For example, suppose one of the input tensors had the negative dispatch key set.
+ # You should end up with a graph that is specialized for tensors that have the negative dispatch key.
+ # If you allow a tensor that does NOT have this bit set, you will accidentally run it "as if" it were negated.
+ # Now, the negative key only shows up for complex numbers, and the export target most likely doesn't
+ # support this feature at all, but the point stands that *some* tensor state only shows up in the dispatch key.
+ # TODO(voz): Either add a dispatch_key check to the guards, or error when users pass in tensors with an
+ # unsupported set of keys during export.
+ #
+ # The list of tensor fields and calls we care about can be found in `terms` below.
+ # TODO(voz): We are missing storage offset in all our tensor guards?
+ if self.check_fn_manager.output_graph.export:
+     self.TYPE_MATCH(guard)
+     code = []
+     terms = [
+         "dtype",
+         "device.type",
+         "device.index",
+         "requires_grad",
+         "ndimension()",
+     ]
+     if not config.dynamic_shapes:
+         terms.append("stride()")
+         # We need to do this to avoid the torch.Size type in guards
+         code.append(f"{tensor_name}.shape == {tuple(value.shape)}")
- # STOP - DO NOT USE id_ref FOR TENSORS - TENSOR INVALIDATION RULES DIFFER
- self.tensor_check_ids[tensor_name] = id(value)
-
- # Note: Guard code produced for tensor_match is a little different.
- # We accumulate tensor names, then do a single install of `___check_tensors`.
- # See _guards.cpp and TensorGuard for more information.
- # TODO(voz): Add tensor matching code to export
- # Note: this is a bit of a special case, and so does not use _produce_guard_code
- guard.set_export_info(
- "TENSOR_MATCH",
- weakref.ref(type(value)),
- None,
- weakref.ref(value),
- )
+     for term in terms:
+         real_value = self.get(tensor_name + "." + term)
+         code.append(f"{tensor_name}.{term} == {real_value}")
+     self._produce_guard_code(guard, code)
+ else:
+     self.tensor_check_names.append(tensor_name)
+     self.tensor_check_examples.append(value)
# A util that appends guarded code, or, in the case of export, adds data onto guards
def _produce_guard_code(
@@ -589,12 +615,12 @@
local_builder.tensor_check_names + global_builder.tensor_check_names
)
- tensor_check_ids = local_builder.tensor_check_ids.copy()
- tensor_check_ids.update(global_builder.tensor_check_ids)
-
check_tensors_fn = None
check_tensors_verbose_fn = None
if tensor_check_names:
+ assert (
+     not self.output_graph.export
+ ), "Illegal to set tensor_check_names in export."
tensor_check_examples = (
local_builder.tensor_check_examples
+ global_builder.tensor_check_examples
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 92641ab..c622848 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -138,7 +138,6 @@
return clone_inputs(self.original_example_inputs)
def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
-
self.restore = checkpoint_params(gm)
self.gm = gm
copy_gm = copy.deepcopy(self.gm)
@@ -186,6 +185,7 @@
super().__init__()
self.graph = torch.fx.Graph()
self.graphargs: List[GraphArg] = []
+ self.export = export
# In export mode, we force the shape_env to strictly disallow any constraining
# of the user marked dynamic dims
fake_mode = torch._subclasses.FakeTensorMode(
@@ -550,7 +550,6 @@
and len(set(stack_values)) == len(stack_values)
and self.side_effects.is_empty()
):
-
# optimization to generate better code in a common case
self.add_output_instructions(
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp
index 5ff74bb..bf20837 100644
--- a/torch/csrc/dynamo/guards.cpp
+++ b/torch/csrc/dynamo/guards.cpp
@@ -44,6 +44,8 @@
}
}
+ // See note in guards.py [Note - On Export Tensor Guards]
+ // The parallel Python logic there must be kept in sync with this check
bool check(const LocalState& state, const at::Tensor& v) {
if (dispatch_key_ != state.apply(v.key_set()).raw_repr() ||
dtype_ != v.dtype().toScalarType() ||