During export, generate Python TENSOR_MATCH guards (#94970)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94970
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 087141a..6556fdf 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -2365,7 +2365,6 @@
         self.assertIs(x_ref(), None)
 
     def test_release_module_memory(self):
-
         mod = torch.nn.Linear(10, 10)
         x = torch.rand([10, 10])
         mod_weight_ref = weakref.ref(mod.weight)
@@ -2711,7 +2710,6 @@
                 self.names = []
 
             def forward(self, idx, targets=None):
-
                 b, t = idx.size()
                 assert (
                     t <= self.block_size
@@ -3832,7 +3830,6 @@
         self.assertTrue(same(ref, res))
 
     def test_disable_flag(self):
-
         cnt = torch._dynamo.testing.CompileCounter()
 
         with patch.dict(os.environ, {"TORCH_COMPILE_DISABLE": "1"}):
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index 41e5b7a..aa64225 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -117,6 +117,7 @@
         # tensor match guards make sure we actually have tensors)
         self.shape_env_code: List[str] = []
 
+        # [Note - On Eager Tensor Guards]
         # Most of the time, we generate Python code in a guard to directly
         # check various properties.  However, tensors are a bit special;
         # it is too slow to check their properties one-by-one in Python.
@@ -131,7 +132,6 @@
         self.tensor_check_names: List[str] = []
         self.tensor_check_examples: List[torch.Tensor] = []
 
-        self.tensor_check_ids: Dict[str, int] = {}
         self.check_fn_manager: CheckFunctionManager = check_fn_manager
 
     # Warning: use this with care!  This lets you access what the current
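
For orientation, the eager path described in [Note - On Eager Tensor Guards] batches all traced tensors into one C++ check instead of per-property Python guards. A minimal sketch of that shape, with illustrative names (the real plumbing lives in CheckFunctionManager and torch/csrc/dynamo/guards.cpp):

```python
# Minimal sketch, assuming illustrative names: the names accumulated in
# tensor_check_names are installed as a single `___check_tensors` call,
# backed by the C++ TensorGuards object, rather than one Python guard
# per tensor property.
tensor_check_names = ["a", "b"]  # filled in by TENSOR_MATCH, as below
guard_expr = f"___check_tensors({', '.join(tensor_check_names)})"
assert guard_expr == "___check_tensors(a, b)"
# The compiled check_fn evaluates this expression once per frame entry.
```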
@@ -429,23 +429,49 @@
             value = self.get(guard.name)
             assert isinstance(value, torch.Tensor)
             tensor_name = self.arg_ref(guard)
-            self.tensor_check_names.append(tensor_name)
-            self.tensor_check_examples.append(value)
+            # [Note - On Export Tensor Guards]
+            #
+            # In eager mode, tensor guards are evaluated in C++; see guards.cpp
+            # and [Note - On Eager Tensor Guards] for more info.
+            #
+            # In export mode, we instead maintain Python logic here that parallels
+            # the C++ logic, with the exception of the dispatch key check - the idea
+            # being that a dispatch key is an entirely runtime notion that makes no
+            # sense to keep in an exported graph.
+            #
+            # This idea is okay but, to paraphrase @ezyang, the mental model is
+            # sufficient for now rather than entirely true.
+            # For example, suppose one of the input tensors had the negative
+            # dispatch key set. You would end up with a graph that is specialized
+            # for tensors with the negative dispatch key. If you then allow a
+            # tensor that does NOT have this bit set, you will accidentally run it
+            # "as if" it were negated. Now, the negative key only shows up for
+            # complex numbers, and most likely the export target doesn't support
+            # this feature at all, but the point stands that *some* tensor state
+            # only shows up in the dispatch key.
+            # TODO(voz): Either populate a dispatch_key check into the guards, or error on users passing in an unsupported
+            # subset of keys during export.
+            #
+            # The list of tensor fields and calls we care about can be found in `terms` below.
+            # TODO(voz): We are missing storage offset in all our tensor guards?
+            if self.check_fn_manager.output_graph.export:
+                self.TYPE_MATCH(guard)
+                code = []
+                terms = [
+                    "dtype",
+                    "device.type",
+                    "device.index",
+                    "requires_grad",
+                    "ndimension()",
+                ]
+                if not config.dynamic_shapes:
+                    terms.append("stride()")
+                    # Compare against a plain tuple to keep the torch.Size type
+                    # out of the guard source
+                    code.append(f"{tensor_name}.shape == {tuple(value.shape)}")
 
-            # STOP - DO NOT USE id_ref FOR TENSORS - TENSOR INVALIDATION RULES DIFFER
-            self.tensor_check_ids[tensor_name] = id(value)
-
-            # Note: Guard code produced for tensor_match is a little different.
-            # We accumulate tensor names, then do a single install of `___check_tensors`.
-            # See _guards.cpp and TensorGuard for more information.
-            # TODO(voz): Add tensor matching code to export
-            # Note: this is a bit of a special case, and so does not use _produce_guard_code
-            guard.set_export_info(
-                "TENSOR_MATCH",
-                weakref.ref(type(value)),
-                None,
-                weakref.ref(value),
-            )
+                for term in terms:
+                    real_value = self.get(tensor_name + "." + term)
+                    code.append(f"{tensor_name}.{term} == {real_value}")
+                self._produce_guard_code(guard, code)
+            else:
+                self.tensor_check_names.append(tensor_name)
+                self.tensor_check_examples.append(value)
 
     # A util that appends guard code, or, in the case of export, adds data onto guards
     def _produce_guard_code(
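
To make the export branch concrete, here is a self-contained sketch of the codegen above; `eval` stands in for `GuardBuilder.get`, and the terms and string formatting mirror the diff (values are interpolated with `str()`, exactly as the produced guard source would read):

```python
import torch

def tensor_match_guard_code(tensor_name, value, dynamic_shapes=False):
    # Mirrors the export branch of TENSOR_MATCH: emit Python source strings
    # that re-check each tensor property against its traced value.
    code = []
    terms = ["dtype", "device.type", "device.index", "requires_grad", "ndimension()"]
    if not dynamic_shapes:
        terms.append("stride()")
        # tuple(...) keeps the torch.Size type out of the guard source
        code.append(f"{tensor_name}.shape == {tuple(value.shape)}")
    for term in terms:
        real_value = eval(f"value.{term}")  # stand-in for GuardBuilder.get
        code.append(f"{tensor_name}.{term} == {real_value}")
    return code

print(tensor_match_guard_code("x", torch.rand(3, 4)))
# ['x.shape == (3, 4)', 'x.dtype == torch.float32', 'x.device.type == cpu',
#  'x.device.index == None', 'x.requires_grad == False',
#  'x.ndimension() == 2', 'x.stride() == (4, 1)']
```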
@@ -589,12 +615,12 @@
             local_builder.tensor_check_names + global_builder.tensor_check_names
         )
 
-        tensor_check_ids = local_builder.tensor_check_ids.copy()
-        tensor_check_ids.update(global_builder.tensor_check_ids)
-
         check_tensors_fn = None
         check_tensors_verbose_fn = None
         if tensor_check_names:
+            assert (
+                not self.output_graph.export
+            ), "Illegal to set tensor_check_names in export."
             tensor_check_examples = (
                 local_builder.tensor_check_examples
                 + global_builder.tensor_check_examples
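
Because export forgoes `___check_tensors` entirely, it also forgoes the dispatch key comparison that the C++ check performs. The caveat from [Note - On Export Tensor Guards] can be made concrete with the negative bit (a hedged illustration, not part of this PR):

```python
import torch

t = torch.tensor([1.0 + 2.0j])
neg = torch._neg_view(t)  # lazily negated view; sets the Negative dispatch key
print(neg.is_neg(), t.is_neg())  # True False
# neg and t agree on dtype, device, shape, and stride, so the Python
# TENSOR_MATCH guards generated above cannot tell them apart; only the
# dispatch key (checked in guards.cpp, but not in export) distinguishes them.
```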
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 92641ab..c622848 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -138,7 +138,6 @@
         return clone_inputs(self.original_example_inputs)
 
     def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
-
         self.restore = checkpoint_params(gm)
         self.gm = gm
         copy_gm = copy.deepcopy(self.gm)
@@ -186,6 +185,7 @@
         super().__init__()
         self.graph = torch.fx.Graph()
         self.graphargs: List[GraphArg] = []
+        self.export = export
         # In export mode, we force the shape_env to strictly disallow any constraining
         # of the user marked dynamic dims
         fake_mode = torch._subclasses.FakeTensorMode(
@@ -550,7 +550,6 @@
             and len(set(stack_values)) == len(stack_values)
             and self.side_effects.is_empty()
         ):
-
             # optimization to generate better code in a common case
             self.add_output_instructions(
                 self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
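
For reference, a hedged usage sketch of how this branch is reached; the exact `torch._dynamo.export` return shape has varied across versions, so treat this as illustrative:

```python
import torch
import torch._dynamo

def f(x):
    return x.sin() + 1

# Export sets OutputGraph.export, so TENSOR_MATCH takes the Python-codegen
# branch and guards arrive as source strings rather than a single C++
# ___check_tensors call.
gm, guards = torch._dynamo.export(f, torch.rand(3, 4))
```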
diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp
index 5ff74bb..bf20837 100644
--- a/torch/csrc/dynamo/guards.cpp
+++ b/torch/csrc/dynamo/guards.cpp
@@ -44,6 +44,8 @@
     }
   }
 
+  // See [Note - On Export Tensor Guards] in guards.py.
+  // The Python logic there parallels this check and must be kept in sync.
   bool check(const LocalState& state, const at::Tensor& v) {
     if (dispatch_key_ != state.apply(v.key_set()).raw_repr() ||
         dtype_ != v.dtype().toScalarType() ||